diff --git a/providers/apache/kafka/docs/operators/index.rst b/providers/apache/kafka/docs/operators/index.rst index 0119b768714e6..f53e7005af0f3 100644 --- a/providers/apache/kafka/docs/operators/index.rst +++ b/providers/apache/kafka/docs/operators/index.rst @@ -1,4 +1,4 @@ - .. Licensed to the Apache Software Foundation (ASF) under one +.. Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file @@ -29,6 +29,9 @@ The operator creates a Kafka Consumer that reads a batch of messages from the cl For parameter definitions take a look at :class:`~airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator`. +.. important:: + If you set the ``commit_cadence`` parameter, ensure that the ``enable.auto.commit`` option in the Kafka connection configuration is explicitly set to ``false``. + By default, ``enable.auto.commit`` is ``true``, which causes the consumer to auto-commit offsets every 5 seconds, potentially overriding the behavior defined by ``commit_cadence``. 
Using the operator """""""""""""""""" diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py index 6abdb7ed8243d..a80c304fdef2a 100644 --- a/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py +++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence -from functools import partial +from functools import cached_property, partial from typing import Any from airflow.exceptions import AirflowException @@ -82,7 +82,7 @@ def __init__( apply_function_batch: Callable[..., Any] | str | None = None, apply_function_args: Sequence[Any] | None = None, apply_function_kwargs: dict[Any, Any] | None = None, - commit_cadence: str | None = "end_of_operator", + commit_cadence: str | None = "end_of_operator", max_messages: int | None = None, max_batch_size: int = 1000, poll_timeout: float = 60, @@ -102,11 +102,7 @@ def __init__( self.poll_timeout = poll_timeout self.read_to_end = self.max_messages is None - - if self.commit_cadence not in VALID_COMMIT_CADENCE: - raise AirflowException( - f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got {self.commit_cadence}" - ) + self._validate_commit_cadence_on_construct() if self.max_messages is not None and self.max_batch_size > self.max_messages: self.log.warning( @@ -117,16 +113,19 @@ def __init__( ) self.max_messages = self.max_batch_size - if self.commit_cadence == "never": - self.commit_cadence = None - if apply_function and apply_function_batch: raise AirflowException( "One of apply_function or apply_function_batch must be supplied, not both."
) + @cached_property + def hook(self): + """Return the KafkaConsumerHook instance.""" + return KafkaConsumerHook(topics=self.topics, kafka_config_id=self.kafka_config_id) + def execute(self, context) -> Any: - consumer = KafkaConsumerHook(topics=self.topics, kafka_config_id=self.kafka_config_id).get_consumer() + self._validate_commit_cadence_before_execute() + consumer = self.hook.get_consumer() if isinstance(self.apply_function, str): self.apply_function = import_string(self.apply_function) @@ -177,10 +176,33 @@ def execute(self, context) -> Any: self.log.info("committing offset at %s", self.commit_cadence) consumer.commit() - if self.commit_cadence: + if self.commit_cadence != "never": self.log.info("committing offset at %s", self.commit_cadence) consumer.commit() consumer.close() return + + def _validate_commit_cadence_on_construct(self): + """Validate the commit_cadence parameter when the operator is constructed.""" + if self.commit_cadence and self.commit_cadence not in VALID_COMMIT_CADENCE: + raise AirflowException( + f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got {self.commit_cadence}" + ) + + def _validate_commit_cadence_before_execute(self): + """Validate the commit_cadence parameter before executing the operator.""" + kafka_config = self.hook.get_connection(self.kafka_config_id).extra_dejson + # Same as kafka's behavior, default to "true" if not set + enable_auto_commit = str(kafka_config.get("enable.auto.commit", "true")).lower() + + if self.commit_cadence and enable_auto_commit != "false": + self.log.warning( + "To respect commit_cadence='%s', " + "'enable.auto.commit' should be set to 'false' in the Kafka connection configuration. " + "Currently, 'enable.auto.commit' is not explicitly set, so it defaults to 'true', which causes " + "the consumer to auto-commit offsets every 5 seconds. 
" + "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit for more information", + self.commit_cadence, + ) diff --git a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py index facd0be92c976..65cad50cec43c 100644 --- a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py +++ b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py @@ -23,10 +23,11 @@ import pytest +from airflow.exceptions import AirflowException from airflow.models import Connection # Import Operator -from airflow.providers.apache.kafka.operators.consume import ConsumeFromTopicOperator +from airflow.providers.apache.kafka.operators.consume import VALID_COMMIT_CADENCE, ConsumeFromTopicOperator log = logging.getLogger(__name__) @@ -40,6 +41,49 @@ def _no_op(*args, **kwargs) -> Any: return args, kwargs +def create_mock_kafka_consumer( + message_count: int = 1001, message_content: Any = "test_message", track_consumed_messages: bool = False +) -> tuple[mock.MagicMock, mock.MagicMock, list[int] | None]: + """ + Creates a mock Kafka consumer with configurable behavior. 
+ + :param message_count: Number of messages to generate + :param message_content: Content of each message + :param track_consumed_messages: Whether to track total consumed messages + :return: Tuple of (mock_consumer, mock_get_consumer, total_consumed_messages) + """ + # Initialize messages and tracking variables + mocked_messages = [message_content for _ in range(message_count)] + total_consumed_count = [0] if track_consumed_messages else None + + # Define the mock consume behavior + def mock_consume(num_messages=0, timeout=-1): + nonlocal mocked_messages + if num_messages < 0: + raise Exception("Number of messages needs to be positive") + + msg_count = min(num_messages, len(mocked_messages)) + returned_messages = mocked_messages[:msg_count] + mocked_messages = mocked_messages[msg_count:] + + if track_consumed_messages: + total_consumed_count[0] += msg_count + + return returned_messages + + # Create mock objects + mock_consumer = mock.MagicMock() + mock_consumer.consume = mock_consume + + mock_get_consumer = mock.patch( + "airflow.providers.apache.kafka.hooks.consume.KafkaConsumerHook.get_consumer", + return_value=mock_consumer, + ) + + # + return mock_consumer, mock_get_consumer, total_consumed_count # type: ignore[return-value] + + class TestConsumeFromTopic: """ Test ConsumeFromTopic @@ -90,28 +134,13 @@ def test_operator_callable(self): ], ) def test_operator_consume(self, max_messages, expected_consumed_messages): - total_consumed_messages = 0 - mocked_messages = ["test_messages" for i in range(1001)] - - def mock_consume(num_messages=0, timeout=-1): - nonlocal mocked_messages - nonlocal total_consumed_messages - if num_messages < 0: - raise Exception("Number of messages needs to be positive") - msg_count = min(num_messages, len(mocked_messages)) - returned_messages = mocked_messages[:msg_count] - mocked_messages = mocked_messages[msg_count:] - total_consumed_messages += msg_count - return returned_messages - - mock_consumer = mock.MagicMock() - 
mock_consumer.consume = mock_consume - - with mock.patch( - "airflow.providers.apache.kafka.hooks.consume.KafkaConsumerHook.get_consumer" - ) as mock_get_consumer: - mock_get_consumer.return_value = mock_consumer + # Create mock consumer with tracking of consumed messages + _, mock_get_consumer, consumed_messages = create_mock_kafka_consumer( + message_count=1001, message_content="test_messages", track_consumed_messages=True + ) + # Use the mock + with mock_get_consumer: operator = ConsumeFromTopicOperator( kafka_config_id="kafka_d", topics=["test"], @@ -122,4 +151,131 @@ def mock_consume(num_messages=0, timeout=-1): # execute the operator (this is essentially a no op as we're mocking the consumer) operator.execute(context={}) - assert total_consumed_messages == expected_consumed_messages + assert consumed_messages[0] == expected_consumed_messages + + @pytest.mark.parametrize( + "commit_cadence", + [ + # will raise AirflowException for invalid commit_cadence + ("invalid_cadence"), + ("end_of_operator"), + ("end_of_batch"), + ("never"), + ], + ) + def test__validate_commit_cadence_on_construct(self, commit_cadence): + operator_kwargs = { + "kafka_config_id": "kafka_d", + "topics": ["test"], + "task_id": "test", + "commit_cadence": commit_cadence, + } + # early return for invalid commit_cadence + if commit_cadence == "invalid_cadence": + with pytest.raises( + AirflowException, + match=f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. 
Got invalid_cadence", + ): + ConsumeFromTopicOperator(**operator_kwargs) + return + + # should not raise AirflowException for valid commit_cadence + ConsumeFromTopicOperator(**operator_kwargs) + + @pytest.mark.parametrize( + "commit_cadence, enable_auto_commit, expected_warning", + [ + # will not log warning if set 'enable.auto.commit' to false + ("end_of_operator", "false", False), + ("end_of_batch", "false", False), + ("never", "false", False), + # will log warning if set 'enable.auto.commit' to true + ("end_of_operator", "true", True), + ("end_of_batch", "true", True), + ("never", "true", True), + # will log warning if 'enable.auto.commit' is not set + ("end_of_operator", None, True), + ("end_of_batch", None, True), + ("never", None, True), + # will not log warning if commit_cadence is None, no matter the value of 'enable.auto.commit' + (None, None, False), + (None, "true", False), + (None, "false", False), + ], + ) + def test__validate_commit_cadence_before_execute( + self, commit_cadence, enable_auto_commit, expected_warning + ): + # mock connection and hook + mocked_hook = mock.MagicMock() + mocked_hook.get_connection.return_value.extra_dejson = ( + {} if enable_auto_commit is None else {"enable.auto.commit": enable_auto_commit} + ) + + with ( + mock.patch( + "airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator.hook", + new_callable=mock.PropertyMock, + return_value=mocked_hook, + ), + mock.patch( + "airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator.log" + ) as mock_log, + ): + operator = ConsumeFromTopicOperator( + kafka_config_id="kafka_d", + topics=["test"], + task_id="test", + commit_cadence=commit_cadence, + ) + operator._validate_commit_cadence_before_execute() + if expected_warning: + expected_warning_template = ( + "To respect commit_cadence='%s', " + "'enable.auto.commit' should be set to 'false' in the Kafka connection configuration. 
" + "Currently, 'enable.auto.commit' is not explicitly set, so it defaults to 'true', which causes " + "the consumer to auto-commit offsets every 5 seconds. " + "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit for more information" + ) + mock_log.warning.assert_called_with(expected_warning_template, commit_cadence) + else: + mock_log.warning.assert_not_called() + + @pytest.mark.parametrize( + "commit_cadence, max_messages, expected_commit_calls", + [ + # end_of_operator: should call commit once at the end + ("end_of_operator", 1500, 1), + # end_of_batch: should call commit after each batch (2 batches for 1500 messages with default batch size 1000) + # and a final commit at the end of execute (since commit_cadence is not 'never') + ("end_of_batch", 1500, 3), + # never: should never call commit + ("never", 1500, 0), + ], + ) + def test_commit_cadence_behavior(self, commit_cadence, max_messages, expected_commit_calls): + # Create mock consumer with 1500 messages (will use 1001 for the first batch) + mock_consumer, mock_get_consumer, _ = create_mock_kafka_consumer( + message_count=1001, # Only need to create 1001 messages for the first batch + ) + + # Use the mocks + with mock_get_consumer: + # Create and execute the operator + operator = ConsumeFromTopicOperator( + kafka_config_id="kafka_d", + topics=["test"], + task_id="test", + poll_timeout=0.0001, + max_messages=max_messages, + commit_cadence=commit_cadence, + apply_function=_no_op, + ) + + operator.execute(context={}) + + # Verify commit was called the expected number of times + assert mock_consumer.commit.call_count == expected_commit_calls + + # Verify consumer was closed + mock_consumer.close.assert_called_once()