Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion providers/apache/kafka/docs/operators/index.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.. Licensed to the Apache Software Foundation (ASF) under one
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
Expand Down Expand Up @@ -29,6 +29,9 @@ The operator creates a Kafka Consumer that reads a batch of messages from the cl

For parameter definitions take a look at :class:`~airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator`.

.. important::
If you set the ``commit_cadence`` parameter, ensure that the ``enable.auto.commit`` option in the Kafka connection configuration is explicitly set to ``false``.
By default, ``enable.auto.commit`` is ``true``, which causes the consumer to auto-commit offsets every 5 seconds, potentially overriding the behavior defined by ``commit_cadence``.

Using the operator
""""""""""""""""""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from functools import partial
from functools import cached_property, partial
from typing import Any

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
apply_function_batch: Callable[..., Any] | str | None = None,
apply_function_args: Sequence[Any] | None = None,
apply_function_kwargs: dict[Any, Any] | None = None,
commit_cadence: str | None = "end_of_operator",
commit_cadence: str = "end_of_operator",
max_messages: int | None = None,
max_batch_size: int = 1000,
poll_timeout: float = 60,
Expand All @@ -102,11 +102,7 @@ def __init__(
self.poll_timeout = poll_timeout

self.read_to_end = self.max_messages is None

if self.commit_cadence not in VALID_COMMIT_CADENCE:
raise AirflowException(
f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got {self.commit_cadence}"
)
self._validate_commit_cadence_on_construct()

if self.max_messages is not None and self.max_batch_size > self.max_messages:
self.log.warning(
Expand All @@ -117,16 +113,19 @@ def __init__(
)
self.max_messages = self.max_batch_size

if self.commit_cadence == "never":
self.commit_cadence = None

if apply_function and apply_function_batch:
raise AirflowException(
"One of apply_function or apply_function_batch must be supplied, not both."
)

@cached_property
def hook(self):
    """KafkaConsumerHook configured with this operator's topics and connection id."""
    return KafkaConsumerHook(
        topics=self.topics,
        kafka_config_id=self.kafka_config_id,
    )

def execute(self, context) -> Any:
consumer = KafkaConsumerHook(topics=self.topics, kafka_config_id=self.kafka_config_id).get_consumer()
self._validate_commit_cadence_before_execute()
consumer = self.hook.get_consumer()

if isinstance(self.apply_function, str):
self.apply_function = import_string(self.apply_function)
Expand Down Expand Up @@ -177,10 +176,33 @@ def execute(self, context) -> Any:
self.log.info("committing offset at %s", self.commit_cadence)
consumer.commit()

if self.commit_cadence:
if self.commit_cadence != "never":
self.log.info("committing offset at %s", self.commit_cadence)
consumer.commit()

consumer.close()

return

def _validate_commit_cadence_on_construct(self):
    """Validate the commit_cadence parameter when the operator is constructed.

    A falsy cadence (e.g. ``None``) is accepted as-is; any other value must be
    one of ``VALID_COMMIT_CADENCE``, otherwise an AirflowException is raised.
    """
    cadence = self.commit_cadence
    if not cadence or cadence in VALID_COMMIT_CADENCE:
        return
    raise AirflowException(
        f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got {self.commit_cadence}"
    )

def _validate_commit_cadence_before_execute(self):
    """Validate the commit_cadence parameter before executing the operator.

    Emits a warning when a commit cadence is configured but Kafka's
    ``enable.auto.commit`` is not explicitly disabled, since auto-commit
    would override the cadence requested by the operator.
    """
    conn_extra = self.hook.get_connection(self.kafka_config_id).extra_dejson
    # Mirror Kafka's own default: "enable.auto.commit" is "true" when unset.
    auto_commit_setting = str(conn_extra.get("enable.auto.commit", "true")).lower()
    auto_commit_disabled = auto_commit_setting == "false"

    if self.commit_cadence and not auto_commit_disabled:
        self.log.warning(
            "To respect commit_cadence='%s', "
            "'enable.auto.commit' should be set to 'false' in the Kafka connection configuration. "
            "Currently, 'enable.auto.commit' is not explicitly set, so it defaults to 'true', which causes "
            "the consumer to auto-commit offsets every 5 seconds. "
            "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit for more information",
            self.commit_cadence,
        )
202 changes: 179 additions & 23 deletions providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.models import Connection

# Import Operator
from airflow.providers.apache.kafka.operators.consume import ConsumeFromTopicOperator
from airflow.providers.apache.kafka.operators.consume import VALID_COMMIT_CADENCE, ConsumeFromTopicOperator

log = logging.getLogger(__name__)

Expand All @@ -40,6 +41,49 @@ def _no_op(*args, **kwargs) -> Any:
return args, kwargs


def create_mock_kafka_consumer(
message_count: int = 1001, message_content: Any = "test_message", track_consumed_messages: bool = False
) -> tuple[mock.MagicMock, mock.MagicMock, list[int] | None]:
"""
Creates a mock Kafka consumer with configurable behavior.

:param message_count: Number of messages to generate
:param message_content: Content of each message
:param track_consumed_messages: Whether to track total consumed messages
:return: Tuple of (mock_consumer, mock_get_consumer, total_consumed_messages)
"""
# Initialize messages and tracking variables
mocked_messages = [message_content for _ in range(message_count)]
total_consumed_count = [0] if track_consumed_messages else None

# Define the mock consume behavior
def mock_consume(num_messages=0, timeout=-1):
nonlocal mocked_messages
if num_messages < 0:
raise Exception("Number of messages needs to be positive")

msg_count = min(num_messages, len(mocked_messages))
returned_messages = mocked_messages[:msg_count]
mocked_messages = mocked_messages[msg_count:]

if track_consumed_messages:
total_consumed_count[0] += msg_count

return returned_messages

# Create mock objects
mock_consumer = mock.MagicMock()
mock_consumer.consume = mock_consume

mock_get_consumer = mock.patch(
"airflow.providers.apache.kafka.hooks.consume.KafkaConsumerHook.get_consumer",
return_value=mock_consumer,
)

#
return mock_consumer, mock_get_consumer, total_consumed_count # type: ignore[return-value]


class TestConsumeFromTopic:
"""
Test ConsumeFromTopic
Expand Down Expand Up @@ -90,28 +134,13 @@ def test_operator_callable(self):
],
)
def test_operator_consume(self, max_messages, expected_consumed_messages):
total_consumed_messages = 0
mocked_messages = ["test_messages" for i in range(1001)]

def mock_consume(num_messages=0, timeout=-1):
nonlocal mocked_messages
nonlocal total_consumed_messages
if num_messages < 0:
raise Exception("Number of messages needs to be positive")
msg_count = min(num_messages, len(mocked_messages))
returned_messages = mocked_messages[:msg_count]
mocked_messages = mocked_messages[msg_count:]
total_consumed_messages += msg_count
return returned_messages

mock_consumer = mock.MagicMock()
mock_consumer.consume = mock_consume

with mock.patch(
"airflow.providers.apache.kafka.hooks.consume.KafkaConsumerHook.get_consumer"
) as mock_get_consumer:
mock_get_consumer.return_value = mock_consumer
# Create mock consumer with tracking of consumed messages
_, mock_get_consumer, consumed_messages = create_mock_kafka_consumer(
message_count=1001, message_content="test_messages", track_consumed_messages=True
)

# Use the mock
with mock_get_consumer:
operator = ConsumeFromTopicOperator(
kafka_config_id="kafka_d",
topics=["test"],
Expand All @@ -122,4 +151,131 @@ def mock_consume(num_messages=0, timeout=-1):

# execute the operator (this is essentially a no op as we're mocking the consumer)
operator.execute(context={})
assert total_consumed_messages == expected_consumed_messages
assert consumed_messages[0] == expected_consumed_messages

@pytest.mark.parametrize(
    "commit_cadence",
    [
        "invalid_cadence",  # rejected with AirflowException
        "end_of_operator",
        "end_of_batch",
        "never",
    ],
)
def test__validate_commit_cadence_on_construct(self, commit_cadence):
    """Constructor rejects unknown commit_cadence values and accepts valid ones."""
    operator_kwargs = {
        "kafka_config_id": "kafka_d",
        "topics": ["test"],
        "task_id": "test",
        "commit_cadence": commit_cadence,
    }

    if commit_cadence == "invalid_cadence":
        # Invalid cadence must fail fast at construction time.
        with pytest.raises(
            AirflowException,
            match=f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got invalid_cadence",
        ):
            ConsumeFromTopicOperator(**operator_kwargs)
    else:
        # Valid cadences construct without raising.
        ConsumeFromTopicOperator(**operator_kwargs)

@pytest.mark.parametrize(
    "commit_cadence, enable_auto_commit, expected_warning",
    [
        # auto-commit explicitly disabled: no warning
        ("end_of_operator", "false", False),
        ("end_of_batch", "false", False),
        ("never", "false", False),
        # auto-commit explicitly enabled: warn
        ("end_of_operator", "true", True),
        ("end_of_batch", "true", True),
        ("never", "true", True),
        # auto-commit unset (Kafka defaults it to "true"): warn
        ("end_of_operator", None, True),
        ("end_of_batch", None, True),
        ("never", None, True),
        # no commit_cadence: never warn, whatever auto-commit says
        (None, None, False),
        (None, "true", False),
        (None, "false", False),
    ],
)
def test__validate_commit_cadence_before_execute(
    self, commit_cadence, enable_auto_commit, expected_warning
):
    """Warning is emitted iff a cadence is set and auto-commit is not disabled."""
    # Fake the hook so the connection's extras carry the desired setting.
    extras = {} if enable_auto_commit is None else {"enable.auto.commit": enable_auto_commit}
    mocked_hook = mock.MagicMock()
    mocked_hook.get_connection.return_value.extra_dejson = extras

    hook_patcher = mock.patch(
        "airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator.hook",
        new_callable=mock.PropertyMock,
        return_value=mocked_hook,
    )
    log_patcher = mock.patch(
        "airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator.log"
    )

    with hook_patcher, log_patcher as mock_log:
        operator = ConsumeFromTopicOperator(
            kafka_config_id="kafka_d",
            topics=["test"],
            task_id="test",
            commit_cadence=commit_cadence,
        )
        operator._validate_commit_cadence_before_execute()

        if not expected_warning:
            mock_log.warning.assert_not_called()
        else:
            expected_warning_template = (
                "To respect commit_cadence='%s', "
                "'enable.auto.commit' should be set to 'false' in the Kafka connection configuration. "
                "Currently, 'enable.auto.commit' is not explicitly set, so it defaults to 'true', which causes "
                "the consumer to auto-commit offsets every 5 seconds. "
                "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit for more information"
            )
            mock_log.warning.assert_called_with(expected_warning_template, commit_cadence)

@pytest.mark.parametrize(
    "commit_cadence, max_messages, expected_commit_calls",
    [
        # end_of_operator: exactly one commit, at the end of execute
        ("end_of_operator", 1500, 1),
        # end_of_batch: one commit per batch (2 batches for 1500 messages at
        # default batch size 1000) plus the final commit at end of execute
        ("end_of_batch", 1500, 3),
        # never: commit is never called
        ("never", 1500, 0),
    ],
)
def test_commit_cadence_behavior(self, commit_cadence, max_messages, expected_commit_calls):
    """commit() is invoked per batch, once at the end, or never, per commit_cadence."""
    # 1001 mocked messages are enough: the mock simply runs dry after that.
    mock_consumer, mock_get_consumer, _ = create_mock_kafka_consumer(message_count=1001)

    with mock_get_consumer:
        operator = ConsumeFromTopicOperator(
            kafka_config_id="kafka_d",
            topics=["test"],
            task_id="test",
            poll_timeout=0.0001,
            max_messages=max_messages,
            commit_cadence=commit_cadence,
            apply_function=_no_op,
        )
        operator.execute(context={})

        # Commit count must match the cadence's contract.
        assert mock_consumer.commit.call_count == expected_commit_calls

        # The consumer is always closed exactly once.
        mock_consumer.close.assert_called_once()