diff --git a/providers/standard/src/airflow/providers/standard/operators/python.py b/providers/standard/src/airflow/providers/standard/operators/python.py index 5bfc546cf0547..3d8902dcfc0d1 100644 --- a/providers/standard/src/airflow/providers/standard/operators/python.py +++ b/providers/standard/src/airflow/providers/standard/operators/python.py @@ -26,6 +26,7 @@ import sys import textwrap import types +import warnings from abc import ABCMeta, abstractmethod from collections.abc import Collection, Container, Iterable, Mapping, Sequence from functools import cache @@ -38,6 +39,7 @@ from airflow.exceptions import ( AirflowConfigException, AirflowException, + AirflowProviderDeprecationWarning, AirflowSkipException, DeserializingResultError, ) @@ -1113,6 +1115,13 @@ def my_task(): was starting to execute. """ if AIRFLOW_V_3_0_PLUS: + warnings.warn( + "Using get_current_context from standard provider is deprecated and will be removed." + "Please import `from airflow.sdk import get_current_context` and use it instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + from airflow.sdk import get_current_context return get_current_context() diff --git a/providers/standard/tests/unit/standard/operators/test_python.py b/providers/standard/tests/unit/standard/operators/test_python.py index 45a1e395e1a16..9f60d404837dd 100644 --- a/providers/standard/tests/unit/standard/operators/test_python.py +++ b/providers/standard/tests/unit/standard/operators/test_python.py @@ -44,6 +44,7 @@ from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG from airflow.exceptions import ( AirflowException, + AirflowProviderDeprecationWarning, DeserializingResultError, ) from airflow.models.baseoperator import BaseOperator @@ -1810,22 +1811,35 @@ def default_kwargs(*, python_version=DEFAULT_PYTHON_VERSION, **kwargs): class TestCurrentContext: def test_current_context_no_context_raise(self): - with pytest.raises(RuntimeError): - get_current_context() + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + with pytest.raises(RuntimeError): + get_current_context() + else: + with pytest.raises(RuntimeError): + get_current_context() def test_current_context_roundtrip(self): example_context = {"Hello": "World"} - with set_current_context(example_context): - assert get_current_context() == example_context + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + assert get_current_context() == example_context + else: + assert get_current_context() == example_context def test_context_removed_after_exit(self): example_context = {"Hello": "World"} with set_current_context(example_context): pass - with pytest.raises(RuntimeError): - get_current_context() + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + with pytest.raises(RuntimeError): + get_current_context() + else: + with pytest.raises(RuntimeError): + get_current_context() def test_nested_context(self): """ @@ -1842,12 +1856,21 @@ def test_nested_context(self): ctx_obj = set_current_context(new_context) ctx_obj.__enter__() ctx_list.append(ctx_obj) - for i in reversed(range(max_stack_depth)): - # Iterate over contexts in reverse order - stack is LIFO - ctx = get_current_context() - assert ctx["ContextId"] == i - # End of with statement - ctx_list[i].__exit__(None, None, None) + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + for i in reversed(range(max_stack_depth)): + # Iterate over contexts in reverse order - stack is LIFO + ctx = get_current_context() + assert ctx["ContextId"] == i + # End of with statement + ctx_list[i].__exit__(None, None, None) + else: + for i in reversed(range(max_stack_depth)): + # Iterate over contexts in reverse order - stack is LIFO + ctx = get_current_context() + assert ctx["ContextId"] == i + # End of with statement + ctx_list[i].__exit__(None, None, None) class MyContextAssertOperator(BaseOperator): @@ -1889,12 +1912,20 @@ class TestCurrentContextRuntime: def test_context_in_task(self): with DAG(dag_id="assert_context_dag", default_args=DEFAULT_ARGS, schedule="@once"): op = MyContextAssertOperator(task_id="assert_context") - op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + else: + op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) def test_get_context_in_old_style_context_task(self): with DAG(dag_id="edge_case_context_dag", default_args=DEFAULT_ARGS, schedule="@once"): op = PythonOperator(python_callable=get_all_the_context, task_id="get_all_the_context") - op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + if AIRFLOW_V_3_0_PLUS: + with pytest.warns(AirflowProviderDeprecationWarning): + op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) + else: + op.run(ignore_first_depends_on_past=True, ignore_ti_state=True) @pytest.mark.need_serialized_dag(False)