From ce3ab7601a28f7e5c0b5ef568f3d0f807872f17b Mon Sep 17 00:00:00 2001 From: Ramit Kataria Date: Mon, 10 Nov 2025 12:54:09 -0800 Subject: [PATCH] Fix Enum-str interpolation in callback metrics in Python 3.11+ Followup to https://github.com/apache/airflow/pull/57215#pullrequestreview-3426156849 Even though CallbackState is a string Enum and works fine with string comparisons, the default __str__ method was changed in Python 3.11 causing test failures due to unexpected metric name: airflow-core/tests/unit/models/test_callback.py:105: in test_get_metric_info assert metric_info["stat"] == "deadline_alerts.callback_success" E AssertionError: assert equals failed E 'deadline_alerts.callback_CallbackState.SUCCESS' 'deadline_alerts.callback_success' This overrides __str__ in CallbackState so it behaves consistently across python versions. --- airflow-core/src/airflow/models/callback.py | 5 ++++- airflow-core/tests/unit/models/test_callback.py | 6 +++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/airflow-core/src/airflow/models/callback.py b/airflow-core/src/airflow/models/callback.py index bfaa43db6b2b3..15925e16dd4fb 100644 --- a/airflow-core/src/airflow/models/callback.py +++ b/airflow-core/src/airflow/models/callback.py @@ -50,6 +50,9 @@ class CallbackState(str, Enum): SUCCESS = "success" FAILED = "failed" + def __str__(self) -> str: + return self.value + ACTIVE_STATES = frozenset((CallbackState.QUEUED, CallbackState.RUNNING)) TERMINAL_STATES = frozenset((CallbackState.SUCCESS, CallbackState.FAILED)) @@ -153,7 +156,7 @@ def __init__(self, priority_weight: int = 1, prefix: str = "", **kwargs): def queue(self): self.state = CallbackState.QUEUED - def get_metric_info(self, status: str, result: Any) -> dict: + def get_metric_info(self, status: CallbackState, result: Any) -> dict: tags = {"result": result, **self.data} tags.pop("prefix", None) diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py index 36e09713a3d04..ed640ab9a975b 100644 --- a/airflow-core/tests/unit/models/test_callback.py +++ b/airflow-core/tests/unit/models/test_callback.py @@ -100,7 +100,7 @@ class UnknownCallback: def test_get_metric_info(self): callback = TriggererCallback(TEST_ASYNC_CALLBACK, prefix="deadline_alerts", dag_id=TEST_DAG_ID) callback.data["kwargs"] = {"context": {"dag_id": TEST_DAG_ID}, "email": "test@example.com"} - metric_info = callback.get_metric_info(CallbackState.SUCCESS.value, "0") + metric_info = callback.get_metric_info(CallbackState.SUCCESS, "0") assert metric_info["stat"] == "deadline_alerts.callback_success" assert metric_info["tags"] == { @@ -122,7 +122,7 @@ def test_polymorphic_serde(self, session): assert isinstance(retrieved, TriggererCallback) assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH assert retrieved.data == TEST_ASYNC_CALLBACK.serialize() - assert retrieved.state == CallbackState.PENDING + assert retrieved.state == CallbackState.PENDING.value assert retrieved.output is None assert retrieved.priority_weight == 1 assert retrieved.created_at is not None @@ -192,7 +192,7 @@ def test_polymorphic_serde(self, session): assert isinstance(retrieved, ExecutorCallback) assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH assert retrieved.data == TEST_SYNC_CALLBACK.serialize() - assert retrieved.state == CallbackState.PENDING + assert retrieved.state == CallbackState.PENDING.value assert retrieved.output is None assert retrieved.priority_weight == 1 assert retrieved.created_at is not None