diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index faf77e8240d6c..089e453d02b43 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -125,6 +125,7 @@ class TaskDecoratorCollection: env_vars: dict[str, str] | None = None, inherit_env: bool = True, use_dill: bool = False, + use_airflow_context: bool = False, **kwargs, ) -> TaskDecorator: """Create a decorator to convert the decorated callable to a virtual environment task. @@ -176,6 +177,7 @@ class TaskDecoratorCollection: :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize the args and result (pickle is default). This allows more complex types but requires you to include dill in your requirements. + :param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable. """ @overload def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... @@ -192,6 +194,7 @@ class TaskDecoratorCollection: env_vars: dict[str, str] | None = None, inherit_env: bool = True, use_dill: bool = False, + use_airflow_context: bool = False, **kwargs, ) -> TaskDecorator: """Create a decorator to convert the decorated callable to a virtual environment task. @@ -225,6 +228,7 @@ class TaskDecoratorCollection: :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize the args and result (pickle is default). This allows more complex types but requires you to include dill in your requirements. + :param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable. """ @overload def branch( # type: ignore[misc] @@ -258,6 +262,7 @@ class TaskDecoratorCollection: venv_cache_path: None | str = None, show_return_value_in_logs: bool = True, use_dill: bool = False, + use_airflow_context: bool = False, **kwargs, ) -> TaskDecorator: """Create a decorator to wrap the decorated callable into a BranchPythonVirtualenvOperator. @@ -299,6 +304,7 @@ class TaskDecoratorCollection: :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize the args and result (pickle is default). This allows more complex types but requires you to include dill in your requirements. + :param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable. """ @overload def branch_virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ... diff --git a/airflow/example_dags/example_python_context_decorator.py b/airflow/example_dags/example_python_context_decorator.py new file mode 100644 index 0000000000000..497ee08e17cea --- /dev/null +++ b/airflow/example_dags/example_python_context_decorator.py @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +Example DAG demonstrating the usage of the PythonOperator with `get_current_context()` to get the current context. + +Also, demonstrates the usage of the TaskFlow API. +""" + +from __future__ import annotations + +import sys + +import pendulum + +from airflow.decorators import dag, task + +SOME_EXTERNAL_PYTHON = sys.executable + + +@dag( + schedule=None, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + tags=["example"], +) +def example_python_context_decorator(): + # [START get_current_context] + @task(task_id="print_the_context") + def print_context() -> str: + """Print the Airflow context.""" + from pprint import pprint + + from airflow.operators.python import get_current_context + + context = get_current_context() + pprint(context) + return "Whatever you return gets printed in the logs" + + print_the_context = print_context() + # [END get_current_context] + + # [START get_current_context_venv] + @task.virtualenv(task_id="print_the_context_venv", use_airflow_context=True) + def print_context_venv() -> str: + """Print the Airflow context in venv.""" + from pprint import pprint + + from airflow.operators.python import get_current_context + + context = get_current_context() + pprint(context) + return "Whatever you return gets printed in the logs" + + print_the_context_venv = print_context_venv() + # [END get_current_context_venv] + + # [START get_current_context_external] + @task.external_python( + task_id="print_the_context_external", python=SOME_EXTERNAL_PYTHON, use_airflow_context=True + ) + def print_context_external() -> str: + """Print the Airflow context in external python.""" + from pprint import pprint + + from airflow.operators.python import get_current_context + + context = get_current_context() + pprint(context) + return "Whatever you return gets printed in the logs" + + print_the_context_external = print_context_external() + # [END get_current_context_external] + + _ = print_the_context >> [print_the_context_venv, print_the_context_external] + + +example_python_context_decorator() diff --git a/airflow/example_dags/example_python_context_operator.py b/airflow/example_dags/example_python_context_operator.py new file mode 100644 index 0000000000000..f1b76c527cfd6 --- /dev/null +++ b/airflow/example_dags/example_python_context_operator.py @@ -0,0 +1,91 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example DAG demonstrating the usage of the PythonOperator with `get_current_context()` to get the current context. + +Also, demonstrates the usage of the classic Python operators. 
+""" + +from __future__ import annotations + +import sys + +import pendulum + +from airflow import DAG +from airflow.operators.python import ExternalPythonOperator, PythonOperator, PythonVirtualenvOperator + +SOME_EXTERNAL_PYTHON = sys.executable + +with DAG( + dag_id="example_python_context_operator", + schedule=None, + start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), + catchup=False, + tags=["example"], +) as dag: + # [START get_current_context] + def print_context() -> str: + """Print the Airflow context.""" + from pprint import pprint + + from airflow.operators.python import get_current_context + + context = get_current_context() + pprint(context) + return "Whatever you return gets printed in the logs" + + print_the_context = PythonOperator(task_id="print_the_context", python_callable=print_context) + # [END get_current_context] + + # [START get_current_context_venv] + def print_context_venv() -> str: + """Print the Airflow context in venv.""" + from pprint import pprint + + from airflow.operators.python import get_current_context + + context = get_current_context() + pprint(context) + return "Whatever you return gets printed in the logs" + + print_the_context_venv = PythonVirtualenvOperator( + task_id="print_the_context_venv", python_callable=print_context_venv, use_airflow_context=True + ) + # [END get_current_context_venv] + + # [START get_current_context_external] + def print_context_external() -> str: + """Print the Airflow context in external python.""" + from pprint import pprint + + from airflow.operators.python import get_current_context + + context = get_current_context() + pprint(context) + return "Whatever you return gets printed in the logs" + + print_the_context_external = ExternalPythonOperator( + task_id="print_the_context_external", + python_callable=print_context_external, + python=SOME_EXTERNAL_PYTHON, + use_airflow_context=True, + ) + # [END get_current_context_external] + + _ = print_the_context >> [print_the_context_venv, print_the_context_external] diff --git a/airflow/operators/python.py b/airflow/operators/python.py index fdfe575fb927f..09b2644beeee9 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -49,19 +49,23 @@ from airflow.models.taskinstance import _CURRENT_CONTEXT from airflow.models.variable import Variable from airflow.operators.branch import BranchMixIn +from airflow.settings import _ENABLE_AIP_44 from airflow.typing_compat import Literal from airflow.utils import hashlib_wrapper from airflow.utils.context import context_copy_partial, context_get_outlet_events, context_merge from airflow.utils.file import get_unique_dag_module_name from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters from airflow.utils.process_utils import execute_in_subprocess +from airflow.utils.pydantic import is_pydantic_2_installed from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script +from airflow.utils.session import create_session log = logging.getLogger(__name__) if TYPE_CHECKING: from pendulum.datetime import DateTime + from airflow.serialization.enums import Encoding from airflow.utils.context import Context @@ -442,6 +446,7 @@ def __init__( env_vars: dict[str, str] | None = None, inherit_env: bool = True, use_dill: bool = False, + use_airflow_context: bool = False, **kwargs, ): if ( @@ -481,6 +486,7 @@ def __init__( f"Expected one of {', '.join(map(repr, _SERIALIZERS))}" ) raise AirflowException(msg) + self.pickling_library = _SERIALIZERS[serializer] self.serializer: 
_SerializerTypeDef = serializer @@ -494,6 +500,7 @@ def __init__( ) self.env_vars = env_vars self.inherit_env = inherit_env + self.use_airflow_context = use_airflow_context @abstractmethod def _iter_serializable_context_keys(self): @@ -540,10 +547,15 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): string_args_path = tmp_dir / "string_args.txt" script_path = tmp_dir / "script.py" termination_log_path = tmp_dir / "termination.log" + airflow_context_path = tmp_dir / "airflow_context.json" self._write_args(input_path) self._write_string_args(string_args_path) + if self.use_airflow_context and (not is_pydantic_2_installed() or not _ENABLE_AIP_44): + error_msg = "`get_current_context()` needs to be used with Pydantic 2 and AIP-44 enabled." + raise AirflowException(error_msg) + jinja_context = { "op_args": self.op_args, "op_kwargs": op_kwargs, @@ -551,6 +563,7 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): "pickling_library": self.serializer, "python_callable": self.python_callable.__name__, "python_callable_source": self.get_python_source(), + "use_airflow_context": self.use_airflow_context, } if inspect.getfile(self.python_callable) == self.dag.fileloc: @@ -561,6 +574,19 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): filename=os.fspath(script_path), render_template_as_native_obj=self.dag.render_template_as_native_obj, ) + if self.use_airflow_context: + from airflow.serialization.serialized_objects import BaseSerialization + + context = get_current_context() + with create_session() as session: + # FIXME: DetachedInstanceError + dag_run, task_instance = context["dag_run"], context["task_instance"] + session.add_all([dag_run, task_instance]) + serializable_context: dict[Encoding, Any] = BaseSerialization.serialize( + context, use_pydantic_models=True + ) + with airflow_context_path.open("w+") as file: + json.dump(serializable_context, file) env_vars = dict(os.environ) if self.inherit_env else {} if self.env_vars: @@ -575,6 +601,7 @@ def _execute_python_callable_in_subprocess(self, python_path: Path): os.fspath(output_path), os.fspath(string_args_path), os.fspath(termination_log_path), + os.fspath(airflow_context_path), ], env=env_vars, ) @@ -666,6 +693,7 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator): :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize the args and result (pickle is default). This allows more complex types but requires you to include dill in your requirements. + :param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable. """ template_fields: Sequence[str] = tuple( @@ -694,6 +722,7 @@ def __init__( env_vars: dict[str, str] | None = None, inherit_env: bool = True, use_dill: bool = False, + use_airflow_context: bool = False, **kwargs, ): if ( @@ -715,6 +744,9 @@ def __init__( ) if not is_venv_installed(): raise AirflowException("PythonVirtualenvOperator requires virtualenv, please install it.") + if use_airflow_context and (not expect_airflow and not system_site_packages): + error_msg = "use_airflow_context is set to True, but expect_airflow and system_site_packages are set to False." 
+            raise AirflowException(error_msg)
         if not requirements:
             self.requirements: list[str] = []
         elif isinstance(requirements, str):
@@ -744,6 +776,7 @@ def __init__(
             env_vars=env_vars,
             inherit_env=inherit_env,
             use_dill=use_dill,
+            use_airflow_context=use_airflow_context,
             **kwargs,
         )
 
@@ -962,6 +995,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
     :param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
         the args and result (pickle is default). This allows more complex types
         but requires you to include dill in your requirements.
+    :param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
     """
 
     template_fields: Sequence[str] = tuple({"python"}.union(PythonOperator.template_fields))
@@ -983,10 +1017,14 @@ def __init__(
         env_vars: dict[str, str] | None = None,
         inherit_env: bool = True,
         use_dill: bool = False,
+        use_airflow_context: bool = False,
         **kwargs,
     ):
         if not python:
             raise ValueError("Python Path must be defined in ExternalPythonOperator")
+        if use_airflow_context and not expect_airflow:
+            error_msg = "use_airflow_context is set to True, but expect_airflow is set to False."
+            raise AirflowException(error_msg)
         self.python = python
         self.expect_pendulum = expect_pendulum
         super().__init__(
@@ -1002,6 +1040,7 @@ def __init__(
             env_vars=env_vars,
             inherit_env=inherit_env,
             use_dill=use_dill,
+            use_airflow_context=use_airflow_context,
             **kwargs,
         )
diff --git a/airflow/utils/python_virtualenv_script.jinja2 b/airflow/utils/python_virtualenv_script.jinja2
index 2ff417985e887..22d68acd755b2 100644
--- a/airflow/utils/python_virtualenv_script.jinja2
+++ b/airflow/utils/python_virtualenv_script.jinja2
@@ -64,6 +64,29 @@ with open(sys.argv[3], "r") as file:
     virtualenv_string_args = list(map(lambda x: x.strip(), list(file)))
 {% endif %}
 
+{% if use_airflow_context | default(false) -%}
+if len(sys.argv) > 5:
+    import json
+    from types import ModuleType
+
+    from airflow.operators import python as airflow_python
+    from airflow.serialization.serialized_objects import BaseSerialization
+
+
+    class _MockPython(ModuleType):
+        @staticmethod
+        def get_current_context():
+            with open(sys.argv[5]) as file:
+                context = json.load(file)
+            return BaseSerialization.deserialize(context, use_pydantic_models=True)
+
+        def __getattr__(self, name: str):
+            return getattr(airflow_python, name)
+
+
+    MockPython = _MockPython("MockPython")
+    sys.modules["airflow.operators.python"] = MockPython
+{% endif %}
 
 try:
     res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
diff --git a/docs/apache-airflow/howto/operator/python.rst b/docs/apache-airflow/howto/operator/python.rst
index b8619cd38bce8..2f0defddd886c 100644
--- a/docs/apache-airflow/howto/operator/python.rst
+++ b/docs/apache-airflow/howto/operator/python.rst
@@ -102,6 +102,37 @@ is evaluated as a :ref:`Jinja template `.
    :start-after: [START howto_operator_python_render_sql]
    :end-before: [END howto_operator_python_render_sql]
 
+Context
+^^^^^^^
+
+The ``Context`` is a dictionary object that contains information
+about the environment of the ``DagRun``.
+For example, accessing the ``task_instance`` key returns the currently running ``TaskInstance`` object.
+
+The context can be used implicitly, such as with ``**kwargs``,
+but it can also be retrieved explicitly with ``get_current_context()``.
+In the latter case, the return value is type-hinted, so static analysis tools can check it.
+
+.. tab-set::
+
+    .. tab-item:: @task
+        :sync: taskflow
+
+        .. exampleinclude:: /../../airflow/example_dags/example_python_context_decorator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context]
+            :end-before: [END get_current_context]
+
+    .. tab-item:: PythonOperator
+        :sync: operator
+
+        .. exampleinclude:: /../../airflow/example_dags/example_python_context_operator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context]
+            :end-before: [END get_current_context]
+
 .. _howto/operator:PythonVirtualenvOperator:
 
 PythonVirtualenvOperator
@@ -203,6 +234,44 @@ In case you have problems during runtime with broken cached virtual environments
 Note that any modification of a cached virtual environment (like temp files in binary path, post-installing further requirements) might
 pollute a cached virtual environment and the operator is not maintaining or cleaning the cache path.
 
+Context
+^^^^^^^
+
+With some limitations, you can also use ``Context`` in virtual environments.
+
+.. important::
+    Using ``Context`` in a virtual environment is harder than in a regular task:
+    the context must be serialized for the subprocess, and the environment must have the libraries needed to deserialize it.
+
+    You can bypass this to some extent by using :ref:`Jinja template variables ` and explicitly passing the values you need as parameters.
+
+    You can also use ``get_current_context()`` in the same way as before, provided the following conditions are met:
+
+    * Requires ``pydantic>=2``.
+
+    * Set ``use_airflow_context`` to ``True`` to call ``get_current_context()`` in the virtual environment.
+
+    * Set ``system_site_packages`` to ``True`` or set ``expect_airflow`` to ``True``.
+
+.. tab-set::
+
+    .. tab-item:: @task.virtualenv
+        :sync: taskflow
+
+        .. exampleinclude:: /../../airflow/example_dags/example_python_context_decorator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context_venv]
+            :end-before: [END get_current_context_venv]
+
+    .. tab-item:: PythonVirtualenvOperator
+        :sync: operator
+
+        .. exampleinclude:: /../../airflow/example_dags/example_python_context_operator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context_venv]
+            :end-before: [END get_current_context_venv]
 
 .. _howto/operator:ExternalPythonOperator:
 
 ExternalPythonOperator
@@ -267,6 +336,31 @@ If you want the context related to datetime objects like ``data_interval_start``
 If you want to pass variables into the classic :class:`~airflow.operators.python.ExternalPythonOperator`
 use ``op_args`` and ``op_kwargs``.
 
+Context
+^^^^^^^
+
+You can use ``Context`` under the same conditions as ``PythonVirtualenvOperator``.
+
+.. tab-set::
+
+    .. tab-item:: @task.external_python
+        :sync: taskflow
+
+        .. exampleinclude:: /../../airflow/example_dags/example_python_context_decorator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context_external]
+            :end-before: [END get_current_context_external]
+
+    .. tab-item:: ExternalPythonOperator
+        :sync: operator
+
+        .. exampleinclude:: /../../airflow/example_dags/example_python_context_operator.py
+            :language: python
+            :dedent: 4
+            :start-after: [START get_current_context_external]
+            :end-before: [END get_current_context_external]
+
 .. _howto/operator:PythonBranchOperator:
 
 PythonBranchOperator
diff --git a/newsfragments/41039.feature.rst b/newsfragments/41039.feature.rst
new file mode 100644
index 0000000000000..c696d25f874a8
--- /dev/null
+++ b/newsfragments/41039.feature.rst
@@ -0,0 +1 @@
+Enable ``get_current_context()`` to work in virtual environments. The following operators are affected: ``PythonVirtualenvOperator``, ``BranchPythonVirtualenvOperator``, ``ExternalPythonOperator``, ``BranchExternalPythonOperator``.
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 993d70cad3340..107adcc12c7d3 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -39,7 +39,11 @@
 from slugify import slugify
 
 from airflow.decorators import task_group
-from airflow.exceptions import AirflowException, DeserializingResultError, RemovedInAirflow3Warning
+from airflow.exceptions import (
+    AirflowException,
+    DeserializingResultError,
+    RemovedInAirflow3Warning,
+)
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
 from airflow.models.taskinstance import TaskInstance, clear_task_instances, set_current_context
@@ -56,8 +60,10 @@
     _PythonVersionInfo,
     get_current_context,
 )
+from airflow.settings import _ENABLE_AIP_44
 from airflow.utils import timezone
 from airflow.utils.context import AirflowContextDeprecationWarning, Context
+from airflow.utils.pydantic import is_pydantic_2_installed
 from airflow.utils.python_virtualenv import prepare_virtualenv
 from airflow.utils.session import create_session
 from airflow.utils.state import DagRunState, State, TaskInstanceState
@@ -82,6 +88,11 @@
 CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None
 CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed")
 
+HAS_PYDANTIC_2 = is_pydantic_2_installed()
+USE_AIRFLOW_CONTEXT_MARKER = pytest.mark.skipif(
+    not HAS_PYDANTIC_2 or not _ENABLE_AIP_44, reason="`pydantic>=2` is not installed or AIP-44 is not enabled"
+)
+
 
 class BasePythonTest:
     """Base test class for TestPythonOperator and TestPythonSensor classes"""
@@ -1005,6 +1016,99 @@ def f():
         task = self.run_as_task(f, env_vars={"MY_ENV_VAR": "EFGHI"}, inherit_env=True)
         assert task.execute_callable() == "EFGHI"
 
+    @USE_AIRFLOW_CONTEXT_MARKER
+    def test_current_context(self):
+        def f():
+            from airflow.operators.python import get_current_context
+            from airflow.utils.context import Context
+
+            context = get_current_context()
+            if not isinstance(context, Context):  # type: ignore[misc]
+                error_msg = f"Expected Context, got {type(context)}"
+                raise TypeError(error_msg)
+
+            return []
+
+        ti = self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True)
+        assert ti.state == TaskInstanceState.SUCCESS
+
+    @USE_AIRFLOW_CONTEXT_MARKER
+    def test_current_context_not_found_error(self):
+        def f():
+            from airflow.operators.python import get_current_context
+
+            get_current_context()
+            return []
+
+        with pytest.raises(
+            AirflowException,
+            match="Current context was requested but no context was found! "
+            "Are you running within an airflow task?",
+        ):
+            self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=False)
+
+    @USE_AIRFLOW_CONTEXT_MARKER
+    def test_current_context_airflow_not_found_error(self):
+        airflow_flag: dict[str, bool] = {"expect_airflow": False}
+        error_msg = "use_airflow_context is set to True, but expect_airflow is set to False."
+
+        if not issubclass(self.opcls, ExternalPythonOperator):
+            airflow_flag["system_site_packages"] = False
+            error_msg = "use_airflow_context is set to True, but expect_airflow and system_site_packages are set to False."
+ + def f(): + from airflow.operators.python import get_current_context + + get_current_context() + return [] + + with pytest.raises(AirflowException, match=error_msg): + self.run_as_task( + f, return_ti=True, multiple_outputs=False, use_airflow_context=True, **airflow_flag + ) + + @USE_AIRFLOW_CONTEXT_MARKER + def test_use_airflow_context_touch_other_variables(self): + def f(): + from airflow.operators.python import get_current_context + from airflow.utils.context import Context + + context = get_current_context() + if not isinstance(context, Context): # type: ignore[misc] + error_msg = f"Expected Context, got {type(context)}" + raise TypeError(error_msg) + + from airflow.operators.python import PythonOperator # noqa: F401 + + return [] + + ti = self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) + assert ti.state == TaskInstanceState.SUCCESS + + @pytest.mark.skipif(HAS_PYDANTIC_2, reason="`pydantic>=2` is installed") + def test_use_airflow_context_without_pydantic_v2_error(self): + def f(): + from airflow.operators.python import get_current_context + + get_current_context() + return [] + + error_msg = "`get_current_context()` needs to be used with Pydantic 2 and AIP-44 enabled." + with pytest.raises(AirflowException, match=re.escape(error_msg)): + self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) + + @pytest.mark.skipif(_ENABLE_AIP_44, reason="AIP-44 is enabled") + def test_use_airflow_context_without_aip_44_error(self): + def f(): + from airflow.operators.python import get_current_context + + get_current_context() + return [] + + error_msg = "`get_current_context()` needs to be used with Pydantic 2 and AIP-44 enabled." + with pytest.raises(AirflowException, match=re.escape(error_msg)): + self.run_as_task(f, return_ti=True, multiple_outputs=False, use_airflow_context=True) + venv_cache_path = tempfile.mkdtemp(prefix="venv_cache_path") @@ -1426,6 +1530,30 @@ def f( self.run_as_task(f, serializer=serializer, system_site_packages=False, requirements=None) + @USE_AIRFLOW_CONTEXT_MARKER + def test_current_context_system_site_packages(self, session): + def f(): + from airflow.operators.python import get_current_context + from airflow.utils.context import Context + + context = get_current_context() + if not isinstance(context, Context): # type: ignore[misc] + error_msg = f"Expected Context, got {type(context)}" + raise TypeError(error_msg) + + return [] + + ti = self.run_as_task( + f, + return_ti=True, + multiple_outputs=False, + use_airflow_context=True, + session=session, + expect_airflow=False, + system_site_packages=True, + ) + assert ti.state == TaskInstanceState.SUCCESS + # when venv tests are run in parallel to other test they create new processes and this might take # quite some time in shared docker environment and get some contention even between different containers @@ -1745,6 +1873,30 @@ def default_kwargs(*, python_version=DEFAULT_PYTHON_VERSION, **kwargs): kwargs["venv_cache_path"] = venv_cache_path return kwargs + @USE_AIRFLOW_CONTEXT_MARKER + def test_current_context_system_site_packages(self, session): + def f(): + from airflow.operators.python import get_current_context + from airflow.utils.context import Context + + context = get_current_context() + if not isinstance(context, Context): # type: ignore[misc] + error_msg = f"Expected Context, got {type(context)}" + raise TypeError(error_msg) + + return [] + + ti = self.run_as_task( + f, + return_ti=True, + multiple_outputs=False, + use_airflow_context=True, + 
session=session, + expect_airflow=False, + system_site_packages=True, + ) + assert ti.state == TaskInstanceState.SUCCESS + # when venv tests are run in parallel to other test they create new processes and this might take # quite some time in shared docker environment and get some contention even between different containers
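
A minimal end-to-end sketch of the new flag, for anyone who wants to try it outside the bundled example DAGs. It assumes an environment with this patch applied, `pydantic>=2` installed, and AIP-44 enabled (the combination enforced in `_execute_python_callable_in_subprocess`); the DAG id, task id, and the choice of the plain-string `ts` context key are illustrative, not part of the patch.

    from __future__ import annotations

    import pendulum

    from airflow.decorators import dag, task


    @dag(
        schedule=None,
        start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
        catchup=False,
    )
    def use_airflow_context_sketch():
        # system_site_packages=True (the default) satisfies the new constructor
        # check: the subprocess must be able to import `airflow` so that the
        # generated script can deserialize the context written by the parent.
        @task.virtualenv(system_site_packages=True, use_airflow_context=True)
        def read_timestamp() -> str:
            # Inside the subprocess this import resolves to the `_MockPython`
            # shim installed by the generated script, which loads the
            # JSON-serialized context from sys.argv[5].
            from airflow.operators.python import get_current_context

            context = get_current_context()
            # `ts` is a plain string key, so it round-trips through
            # BaseSerialization without relying on pydantic models.
            return context["ts"]

        read_timestamp()


    use_airflow_context_sketch()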