From 7c502686f3bf5392c3f32ad9d35ceec893d5c3dd Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Sun, 22 Jun 2025 18:40:24 +0800 Subject: [PATCH 1/5] Add validation for commit_cadence --- .../apache/kafka/operators/consume.py | 42 ++++++++--- .../apache/kafka/operators/test_consume.py | 71 ++++++++++++++++++- 2 files changed, 102 insertions(+), 11 deletions(-) diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py index 6abdb7ed8243d..ee08a14767377 100644 --- a/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py +++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py @@ -17,7 +17,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence -from functools import partial +from functools import cached_property, partial from typing import Any from airflow.exceptions import AirflowException @@ -102,11 +102,7 @@ def __init__( self.poll_timeout = poll_timeout self.read_to_end = self.max_messages is None - - if self.commit_cadence not in VALID_COMMIT_CADENCE: - raise AirflowException( - f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got {self.commit_cadence}" - ) + self._validate_commit_cadence() if self.max_messages is not None and self.max_batch_size > self.max_messages: self.log.warning( @@ -117,16 +113,18 @@ def __init__( ) self.max_messages = self.max_batch_size - if self.commit_cadence == "never": - self.commit_cadence = None - if apply_function and apply_function_batch: raise AirflowException( "One of apply_function or apply_function_batch must be supplied, not both." 
) + @cached_property + def hook(self): + """Return the KafkaConsumerHook instance.""" + return KafkaConsumerHook(topics=self.topics, kafka_config_id=self.kafka_config_id) + def execute(self, context) -> Any: - consumer = KafkaConsumerHook(topics=self.topics, kafka_config_id=self.kafka_config_id).get_consumer() + consumer = self.hook.get_consumer() if isinstance(self.apply_function, str): self.apply_function = import_string(self.apply_function) @@ -184,3 +182,27 @@ def execute(self, context) -> Any: consumer.close() return + + def _validate_commit_cadence(self): + """Validate the commit cadence configuration.""" + if self.commit_cadence and self.commit_cadence not in VALID_COMMIT_CADENCE: + raise AirflowException( + f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got {self.commit_cadence}" + ) + + kafka_config = self.hook.get_connection(self.kafka_config_id).extra_dejson + # Same as kafka's behavior, default to "true" if not set + enable_auto_commit = kafka_config.get("enable.auto.commit", "true").lower() + + if self.commit_cadence and enable_auto_commit != "false": + self.log.warning( + "To respect commit_cadence='%s', " + "'enable.auto.commit' should be set to 'false' in the Kafka connection configuration. " + "Currently, 'enable.auto.commit' is not explicitly set, so it defaults to 'true', which causes " + "the consumer to auto-commit offsets every 5 seconds. 
" + "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit", + self.commit_cadence, + ) + + if self.commit_cadence == "never": + self.commit_cadence = None diff --git a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py index facd0be92c976..e8bd3b83e7df4 100644 --- a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py +++ b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py @@ -23,10 +23,11 @@ import pytest +from airflow.exceptions import AirflowException from airflow.models import Connection # Import Operator -from airflow.providers.apache.kafka.operators.consume import ConsumeFromTopicOperator +from airflow.providers.apache.kafka.operators.consume import VALID_COMMIT_CADENCE, ConsumeFromTopicOperator log = logging.getLogger(__name__) @@ -123,3 +124,71 @@ def mock_consume(num_messages=0, timeout=-1): # execute the operator (this is essentially a no op as we're mocking the consumer) operator.execute(context={}) assert total_consumed_messages == expected_consumed_messages + + @pytest.mark.parametrize( + "commit_cadence, enable_auto_commit, expected_warning", + [ + # will raise AirflowException for invalid commit_cadence + ("invalid_cadence", "false", False), + # will not log warning if set 'enable.auto.commit' to false + ("end_of_operator", "false", False), + ("end_of_batch", "false", False), + ("never", "false", False), + # will log warning if set 'enable.auto.commit' to true + ("end_of_operator", "true", True), + ("end_of_batch", "true", True), + ("never", "true", True), + # will log warning if 'enable.auto.commit' is not set + ("end_of_operator", None, True), + ("end_of_batch", None, True), + ("never", None, True), + # will not log warning if commit_cadence is None, no matter the value of 'enable.auto.commit' + (None, None, False), + (None, "true", False), + (None, "false", False), + ], + ) + def 
test__validate_commit_cadence(self, commit_cadence, enable_auto_commit, expected_warning): + operator_kwargs = { + "kafka_config_id": "kafka_d", + "topics": ["test"], + "task_id": "test", + "commit_cadence": commit_cadence, + } + # early return for invalid commit_cadence + if commit_cadence == "invalid_cadence": + with pytest.raises( + AirflowException, + match=f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got invalid_cadence", + ): + ConsumeFromTopicOperator(**operator_kwargs) + return + + # mock connection and hook + mocked_hook = mock.MagicMock() + mocked_hook.get_connection.return_value.extra_dejson = ( + {} if enable_auto_commit is None else {"enable.auto.commit": enable_auto_commit} + ) + + with ( + mock.patch( + "airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator.hook", + new_callable=mock.PropertyMock, + return_value=mocked_hook, + ), + mock.patch( + "airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator.log" + ) as mock_log, + ): + ConsumeFromTopicOperator(**operator_kwargs) + if expected_warning: + expected_warning_template = ( + "To respect commit_cadence='%s', " + "'enable.auto.commit' should be set to 'false' in the Kafka connection configuration. " + "Currently, 'enable.auto.commit' is not explicitly set, so it defaults to 'true', which causes " + "the consumer to auto-commit offsets every 5 seconds. 
" + "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit" + ) + mock_log.warning.assert_called_with(expected_warning_template, commit_cadence) + else: + mock_log.warning.assert_not_called() From d857ad10db2d33b410ae3fd35594e2cd65e5c028 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Sun, 22 Jun 2025 18:41:40 +0800 Subject: [PATCH 2/5] Add important section to mention the behavior --- providers/apache/kafka/docs/operators/index.rst | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/providers/apache/kafka/docs/operators/index.rst b/providers/apache/kafka/docs/operators/index.rst index 0119b768714e6..f53e7005af0f3 100644 --- a/providers/apache/kafka/docs/operators/index.rst +++ b/providers/apache/kafka/docs/operators/index.rst @@ -1,4 +1,4 @@ - .. Licensed to the Apache Software Foundation (ASF) under one +.. Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file @@ -29,6 +29,9 @@ The operator creates a Kafka Consumer that reads a batch of messages from the cl For parameter definitions take a look at :class:`~airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator`. +.. important:: + If you set the ``commit_cadence`` parameter, ensure that the ``enable.auto.commit`` option in the Kafka connection configuration is explicitly set to ``false``. + By default, ``enable.auto.commit`` is ``true``, which causes the consumer to auto-commit offsets every 5 seconds, potentially overriding the behavior defined by ``commit_cadence``. 
Using the operator """""""""""""""""" From 3a76a1c7ac4137ea2e8454393b6262961ec65815 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 23 Jun 2025 17:19:53 +0800 Subject: [PATCH 3/5] Separate validation with on_construct, on_execute --- .../apache/kafka/operators/consume.py | 18 +++--- .../apache/kafka/operators/test_consume.py | 58 +++++++++++++------ 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py index ee08a14767377..33fffa14fdc0d 100644 --- a/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py +++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py @@ -82,7 +82,7 @@ def __init__( apply_function_batch: Callable[..., Any] | str | None = None, apply_function_args: Sequence[Any] | None = None, apply_function_kwargs: dict[Any, Any] | None = None, - commit_cadence: str | None = "end_of_operator", + commit_cadence: str = "end_of_operator", max_messages: int | None = None, max_batch_size: int = 1000, poll_timeout: float = 60, @@ -102,7 +102,7 @@ def __init__( self.poll_timeout = poll_timeout self.read_to_end = self.max_messages is None - self._validate_commit_cadence() + self._validate_commit_cadence_on_construct() if self.max_messages is not None and self.max_batch_size > self.max_messages: self.log.warning( @@ -124,6 +124,7 @@ def hook(self): return KafkaConsumerHook(topics=self.topics, kafka_config_id=self.kafka_config_id) def execute(self, context) -> Any: + self._validate_commit_cadence_before_execute() consumer = self.hook.get_consumer() if isinstance(self.apply_function, str): @@ -175,7 +176,7 @@ def execute(self, context) -> Any: self.log.info("committing offset at %s", self.commit_cadence) consumer.commit() - if self.commit_cadence: + if self.commit_cadence != "never": self.log.info("committing offset at %s", self.commit_cadence) 
consumer.commit() @@ -183,16 +184,18 @@ def execute(self, context) -> Any: return - def _validate_commit_cadence(self): - """Validate the commit cadence configuration.""" + def _validate_commit_cadence_on_construct(self): + """Validate the commit_cadence parameter when the operator is constructed.""" if self.commit_cadence and self.commit_cadence not in VALID_COMMIT_CADENCE: raise AirflowException( f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got {self.commit_cadence}" ) + def _validate_commit_cadence_before_execute(self): + """Validate the commit_cadence parameter before executing the operator.""" kafka_config = self.hook.get_connection(self.kafka_config_id).extra_dejson # Same as kafka's behavior, default to "true" if not set - enable_auto_commit = kafka_config.get("enable.auto.commit", "true").lower() + enable_auto_commit = str(kafka_config.get("enable.auto.commit", "true")).lower() if self.commit_cadence and enable_auto_commit != "false": self.log.warning( @@ -203,6 +206,3 @@ def _validate_commit_cadence(self): "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit", self.commit_cadence, ) - - if self.commit_cadence == "never": - self.commit_cadence = None diff --git a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py index e8bd3b83e7df4..2a7bf1299dbcc 100644 --- a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py +++ b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py @@ -126,10 +126,37 @@ def mock_consume(num_messages=0, timeout=-1): assert total_consumed_messages == expected_consumed_messages @pytest.mark.parametrize( - "commit_cadence, enable_auto_commit, expected_warning", + "commit_cadence", [ # will raise AirflowException for invalid commit_cadence - ("invalid_cadence", "false", False), + ("invalid_cadence"), + ("end_of_operator"), + ("end_of_batch"), + ("never"), + ], + ) + 
def test__validate_commit_cadence_on_construct(self, commit_cadence): + operator_kwargs = { + "kafka_config_id": "kafka_d", + "topics": ["test"], + "task_id": "test", + "commit_cadence": commit_cadence, + } + # early return for invalid commit_cadence + if commit_cadence == "invalid_cadence": + with pytest.raises( + AirflowException, + match=f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. Got invalid_cadence", + ): + ConsumeFromTopicOperator(**operator_kwargs) + return + + # should not raise AirflowException for valid commit_cadence + ConsumeFromTopicOperator(**operator_kwargs) + + @pytest.mark.parametrize( + "commit_cadence, enable_auto_commit, expected_warning", + [ # will not log warning if set 'enable.auto.commit' to false ("end_of_operator", "false", False), ("end_of_batch", "false", False), @@ -148,22 +175,9 @@ def mock_consume(num_messages=0, timeout=-1): (None, "false", False), ], ) - def test__validate_commit_cadence(self, commit_cadence, enable_auto_commit, expected_warning): - operator_kwargs = { - "kafka_config_id": "kafka_d", - "topics": ["test"], - "task_id": "test", - "commit_cadence": commit_cadence, - } - # early return for invalid commit_cadence - if commit_cadence == "invalid_cadence": - with pytest.raises( - AirflowException, - match=f"commit_cadence must be one of {VALID_COMMIT_CADENCE}. 
Got invalid_cadence", - ): - ConsumeFromTopicOperator(**operator_kwargs) - return - + def test__validate_commit_cadence_before_execute( + self, commit_cadence, enable_auto_commit, expected_warning + ): # mock connection and hook mocked_hook = mock.MagicMock() mocked_hook.get_connection.return_value.extra_dejson = ( @@ -180,7 +194,13 @@ def test__validate_commit_cadence(self, commit_cadence, enable_auto_commit, expe "airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator.log" ) as mock_log, ): - ConsumeFromTopicOperator(**operator_kwargs) + operator = ConsumeFromTopicOperator( + kafka_config_id="kafka_d", + topics=["test"], + task_id="test", + commit_cadence=commit_cadence, + ) + operator._validate_commit_cadence_before_execute() if expected_warning: expected_warning_template = ( "To respect commit_cadence='%s', " From 88acbcc40137be46a877899d145ca6d54a594e87 Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU Date: Mon, 23 Jun 2025 20:17:04 +0800 Subject: [PATCH 4/5] Refactor unit test - modularize mock consumer - add test_commit_cadence_behavior --- .../apache/kafka/operators/test_consume.py | 111 ++++++++++++++---- 1 file changed, 89 insertions(+), 22 deletions(-) diff --git a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py index 2a7bf1299dbcc..5e78cb7fa4c3e 100644 --- a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py +++ b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py @@ -41,6 +41,49 @@ def _no_op(*args, **kwargs) -> Any: return args, kwargs +def create_mock_kafka_consumer( + message_count: int = 1001, message_content: Any = "test_message", track_consumed_messages: bool = False +) -> tuple[mock.MagicMock, mock.MagicMock, list[int] | None]: + """ + Creates a mock Kafka consumer with configurable behavior. 
+ + :param message_count: Number of messages to generate + :param message_content: Content of each message + :param track_consumed_messages: Whether to track total consumed messages + :return: Tuple of (mock_consumer, mock_get_consumer, total_consumed_messages) + """ + # Initialize messages and tracking variables + mocked_messages = [message_content for _ in range(message_count)] + total_consumed_count = [0] if track_consumed_messages else None + + # Define the mock consume behavior + def mock_consume(num_messages=0, timeout=-1): + nonlocal mocked_messages + if num_messages < 0: + raise Exception("Number of messages needs to be positive") + + msg_count = min(num_messages, len(mocked_messages)) + returned_messages = mocked_messages[:msg_count] + mocked_messages = mocked_messages[msg_count:] + + if track_consumed_messages: + total_consumed_count[0] += msg_count + + return returned_messages + + # Create mock objects + mock_consumer = mock.MagicMock() + mock_consumer.consume = mock_consume + + mock_get_consumer = mock.patch( + "airflow.providers.apache.kafka.hooks.consume.KafkaConsumerHook.get_consumer", + return_value=mock_consumer, + ) + + # + return mock_consumer, mock_get_consumer, total_consumed_count # type: ignore[return-value] + + class TestConsumeFromTopic: """ Test ConsumeFromTopic @@ -91,28 +134,13 @@ def test_operator_callable(self): ], ) def test_operator_consume(self, max_messages, expected_consumed_messages): - total_consumed_messages = 0 - mocked_messages = ["test_messages" for i in range(1001)] - - def mock_consume(num_messages=0, timeout=-1): - nonlocal mocked_messages - nonlocal total_consumed_messages - if num_messages < 0: - raise Exception("Number of messages needs to be positive") - msg_count = min(num_messages, len(mocked_messages)) - returned_messages = mocked_messages[:msg_count] - mocked_messages = mocked_messages[msg_count:] - total_consumed_messages += msg_count - return returned_messages - - mock_consumer = mock.MagicMock() - 
mock_consumer.consume = mock_consume - - with mock.patch( - "airflow.providers.apache.kafka.hooks.consume.KafkaConsumerHook.get_consumer" - ) as mock_get_consumer: - mock_get_consumer.return_value = mock_consumer + # Create mock consumer with tracking of consumed messages + _, mock_get_consumer, consumed_messages = create_mock_kafka_consumer( + message_count=1001, message_content="test_messages", track_consumed_messages=True + ) + # Use the mock + with mock_get_consumer: operator = ConsumeFromTopicOperator( kafka_config_id="kafka_d", topics=["test"], @@ -123,7 +151,7 @@ def mock_consume(num_messages=0, timeout=-1): # execute the operator (this is essentially a no op as we're mocking the consumer) operator.execute(context={}) - assert total_consumed_messages == expected_consumed_messages + assert consumed_messages[0] == expected_consumed_messages @pytest.mark.parametrize( "commit_cadence", @@ -212,3 +240,42 @@ def test__validate_commit_cadence_before_execute( mock_log.warning.assert_called_with(expected_warning_template, commit_cadence) else: mock_log.warning.assert_not_called() + + @pytest.mark.parametrize( + "commit_cadence, max_messages, expected_commit_calls", + [ + # end_of_operator: should call commit once at the end + ("end_of_operator", 1500, 1), + # end_of_batch: should call commit after each batch (2 batches for 1500 messages with default batch size 1000) + # and a final commit at the end of execute (since commit_cadence is not 'never') + ("end_of_batch", 1500, 3), + # never: should never call commit + ("never", 1500, 0), + ], + ) + def test_commit_cadence_behavior(self, commit_cadence, max_messages, expected_commit_calls): + # Create mock consumer with 1500 messages (will use 1001 for the first batch) + mock_consumer, mock_get_consumer, _ = create_mock_kafka_consumer( + message_count=1001, # Only need to create 1001 messages for the first batch + ) + + # Use the mocks + with mock_get_consumer: + # Create and execute the operator + operator = 
ConsumeFromTopicOperator( + kafka_config_id="kafka_d", + topics=["test"], + task_id="test", + poll_timeout=0.0001, + max_messages=max_messages, + commit_cadence=commit_cadence, + apply_function=_no_op, + ) + + operator.execute(context={}) + + # Verify commit was called the expected number of times + assert mock_consumer.commit.call_count == expected_commit_calls + + # Verify consumer was closed + mock_consumer.close.assert_called_once() From deed801e1e93efeeea6e775479a178f9d8af223b Mon Sep 17 00:00:00 2001 From: LIU ZHE YOU <68415893+jason810496@users.noreply.github.com> Date: Tue, 24 Jun 2025 15:51:42 +0800 Subject: [PATCH 5/5] Fix nit for warning message Co-authored-by: Amogh Desai --- .../src/airflow/providers/apache/kafka/operators/consume.py | 2 +- .../kafka/tests/unit/apache/kafka/operators/test_consume.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py index 33fffa14fdc0d..a80c304fdef2a 100644 --- a/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py +++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/operators/consume.py @@ -203,6 +203,6 @@ def _validate_commit_cadence_before_execute(self): "'enable.auto.commit' should be set to 'false' in the Kafka connection configuration. " "Currently, 'enable.auto.commit' is not explicitly set, so it defaults to 'true', which causes " "the consumer to auto-commit offsets every 5 seconds. 
" - "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit", + "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit for more information", self.commit_cadence, ) diff --git a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py index 5e78cb7fa4c3e..65cad50cec43c 100644 --- a/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py +++ b/providers/apache/kafka/tests/unit/apache/kafka/operators/test_consume.py @@ -235,7 +235,7 @@ def test__validate_commit_cadence_before_execute( "'enable.auto.commit' should be set to 'false' in the Kafka connection configuration. " "Currently, 'enable.auto.commit' is not explicitly set, so it defaults to 'true', which causes " "the consumer to auto-commit offsets every 5 seconds. " - "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit" + "See: https://kafka.apache.org/documentation/#consumerconfigs_enable.auto.commit for more information" ) mock_log.warning.assert_called_with(expected_warning_template, commit_cadence) else: