Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
51 commits
Select commit Hold shift + click to select a range
8aad6cb
temp: ignore breeze hook
phi-friday Jul 26, 2024
a36b4de
feat: context_to_json
phi-friday Jul 26, 2024
7be9199
fix: rm select
phi-friday Jul 26, 2024
214e696
fix: update jinja template
phi-friday Jul 26, 2024
f494313
fix: apply use_airflow_context
phi-friday Jul 26, 2024
8ab5670
feat: add use_airflow_context
phi-friday Jul 26, 2024
bb8bc6b
fix
phi-friday Jul 26, 2024
8924262
fix: use _DEPRECATION_REPLACEMENTS
phi-friday Jul 26, 2024
d59319e
test: dump_airflow_context
phi-friday Jul 26, 2024
5b5e40a
test: add venv operator tests
phi-friday Jul 26, 2024
8f3c8f3
Revert "temp: ignore breeze hook"
phi-friday Jul 26, 2024
d3030cc
fix: mark db_test
phi-friday Jul 26, 2024
59a78c4
fix: review
phi-friday Aug 3, 2024
287384c
tests
phi-friday Aug 3, 2024
50e0cea
branch tests
phi-friday Aug 3, 2024
e543258
prepare docs
phi-friday Aug 3, 2024
9d9acf3
update docs
phi-friday Aug 3, 2024
3163869
fix
phi-friday Aug 3, 2024
2210b33
update docs
phi-friday Aug 3, 2024
de49975
fix: static check errors
phi-friday Aug 3, 2024
a58e585
chore: add newsfragment
phi-friday Aug 3, 2024
5342b24
fix: static check
phi-friday Aug 3, 2024
ab93edb
fix: review
phi-friday Aug 4, 2024
1824934
fix: review
phi-friday Aug 5, 2024
27e1abe
fix: convert ti to simple ti
phi-friday Aug 5, 2024
2d1a3c4
fix: test errors
phi-friday Aug 5, 2024
ac0087d
fix: docs
phi-friday Aug 5, 2024
b483167
Revert "fix: convert ti to simple ti"
phi-friday Aug 5, 2024
ec836ce
fix: use BaseSerialization
phi-friday Aug 5, 2024
0d58509
fix: merge error
phi-friday Aug 5, 2024
27c1040
fix: ti state error
phi-friday Aug 5, 2024
72caa72
fix: static error
phi-friday Aug 5, 2024
db98ffa
Update docs/apache-airflow/howto/operator/python.rst
phi-friday Aug 8, 2024
7f2e0a4
fix: json.load -> BaseSerialization.deserialize
phi-friday Aug 8, 2024
7d26f20
fix: add airflow conditions
phi-friday Aug 8, 2024
ccd6052
fix: json error
phi-friday Aug 8, 2024
a16737e
fix: use pickling_library
phi-friday Aug 8, 2024
7f121ae
Revert "fix: use pickling_library"
phi-friday Aug 8, 2024
1a6d012
fix: use_pydantic_models=True
phi-friday Aug 8, 2024
89bc4a1
tests
phi-friday Aug 8, 2024
f0c2372
fix: static errors
phi-friday Aug 8, 2024
98d6845
fix: update docs
phi-friday Aug 8, 2024
d33e683
fix: update newsfragment
phi-friday Aug 8, 2024
22f7709
fix: BranchMixIn
phi-friday Aug 8, 2024
0bc3ad2
fix: rm comment
phi-friday Aug 9, 2024
f5ec8f9
fix: need pydantic@v2 & aip44
phi-friday Aug 12, 2024
5f9af9a
fix: raise error
phi-friday Aug 12, 2024
916282c
fix: USE_AIRFLOW_CONTEXT_MARKER
phi-friday Aug 12, 2024
ff83290
tests: error test
phi-friday Aug 12, 2024
5ac8d5f
docs: add pydantic@v2
phi-friday Aug 12, 2024
c37e6a4
fix: regex escape
phi-friday Aug 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions airflow/decorators/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class TaskDecoratorCollection:
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
) -> TaskDecorator:
"""Create a decorator to convert the decorated callable to a virtual environment task.
Expand Down Expand Up @@ -176,6 +177,7 @@ class TaskDecoratorCollection:
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""
@overload
def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
Expand All @@ -192,6 +194,7 @@ class TaskDecoratorCollection:
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
) -> TaskDecorator:
"""Create a decorator to convert the decorated callable to a virtual environment task.
Expand Down Expand Up @@ -225,6 +228,7 @@ class TaskDecoratorCollection:
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""
@overload
def branch( # type: ignore[misc]
Expand Down Expand Up @@ -258,6 +262,7 @@ class TaskDecoratorCollection:
venv_cache_path: None | str = None,
show_return_value_in_logs: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
) -> TaskDecorator:
"""Create a decorator to wrap the decorated callable into a BranchPythonVirtualenvOperator.
Expand Down Expand Up @@ -299,6 +304,7 @@ class TaskDecoratorCollection:
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""
@overload
def branch_virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
Expand Down
92 changes: 92 additions & 0 deletions airflow/example_dags/example_python_context_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example DAG demonstrating how to retrieve the current Airflow context with `get_current_context()`.

The tasks are defined with the TaskFlow API decorators (`@task`, `@task.virtualenv` and `@task.external_python`).
"""

from __future__ import annotations

import sys

import pendulum

from airflow.decorators import dag, task

SOME_EXTERNAL_PYTHON = sys.executable


@dag(
    tags=["example"],
    catchup=False,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    schedule=None,
)
def example_python_context_decorator():
    # Every task below fetches the context with ``get_current_context()`` and
    # pretty-prints it. The imports live inside each callable so that the
    # callable is self-contained.

    # [START get_current_context]
    @task(task_id="print_the_context")
    def print_context() -> str:
        """Print the Airflow context."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    regular_task = print_context()
    # [END get_current_context]

    # [START get_current_context_venv]
    @task.virtualenv(use_airflow_context=True, task_id="print_the_context_venv")
    def print_context_venv() -> str:
        """Print the Airflow context in venv."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    venv_task = print_context_venv()
    # [END get_current_context_venv]

    # [START get_current_context_external]
    @task.external_python(
        use_airflow_context=True,
        python=SOME_EXTERNAL_PYTHON,
        task_id="print_the_context_external",
    )
    def print_context_external() -> str:
        """Print the Airflow context in external python."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    external_task = print_context_external()
    # [END get_current_context_external]

    # The plain task runs first; both ``use_airflow_context`` variants follow it.
    downstream_tasks = [venv_task, external_task]
    _ = regular_task >> downstream_tasks


example_python_context_decorator()
91 changes: 91 additions & 0 deletions airflow/example_dags/example_python_context_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example DAG demonstrating how to retrieve the current Airflow context with `get_current_context()`.

The tasks are defined with the classic operators (`PythonOperator`, `PythonVirtualenvOperator` and `ExternalPythonOperator`).
"""

from __future__ import annotations

import sys

import pendulum

from airflow import DAG
from airflow.operators.python import ExternalPythonOperator, PythonOperator, PythonVirtualenvOperator

SOME_EXTERNAL_PYTHON = sys.executable

with DAG(
    dag_id="example_python_context_operator",
    tags=["example"],
    catchup=False,
    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
    schedule=None,
) as dag:
    # Every callable below fetches the context with ``get_current_context()``
    # and pretty-prints it. The imports live inside each callable so that the
    # callable is self-contained.

    # [START get_current_context]
    def print_context() -> str:
        """Print the Airflow context."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    print_the_context = PythonOperator(
        python_callable=print_context,
        task_id="print_the_context",
    )
    # [END get_current_context]

    # [START get_current_context_venv]
    def print_context_venv() -> str:
        """Print the Airflow context in venv."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    print_the_context_venv = PythonVirtualenvOperator(
        use_airflow_context=True,
        python_callable=print_context_venv,
        task_id="print_the_context_venv",
    )
    # [END get_current_context_venv]

    # [START get_current_context_external]
    def print_context_external() -> str:
        """Print the Airflow context in external python."""
        from pprint import pprint

        from airflow.operators.python import get_current_context

        pprint(get_current_context())
        return "Whatever you return gets printed in the logs"

    print_the_context_external = ExternalPythonOperator(
        use_airflow_context=True,
        python=SOME_EXTERNAL_PYTHON,
        python_callable=print_context_external,
        task_id="print_the_context_external",
    )
    # [END get_current_context_external]

    # The plain task runs first; both ``use_airflow_context`` variants follow it.
    _ = print_the_context >> [
        print_the_context_venv,
        print_the_context_external,
    ]
39 changes: 39 additions & 0 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,23 @@
from airflow.models.taskinstance import _CURRENT_CONTEXT
from airflow.models.variable import Variable
from airflow.operators.branch import BranchMixIn
from airflow.settings import _ENABLE_AIP_44
from airflow.typing_compat import Literal
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_get_outlet_events, context_merge
from airflow.utils.file import get_unique_dag_module_name
from airflow.utils.operator_helpers import ExecutionCallableRunner, KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.pydantic import is_pydantic_2_installed
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
from airflow.utils.session import create_session

log = logging.getLogger(__name__)

if TYPE_CHECKING:
from pendulum.datetime import DateTime

from airflow.serialization.enums import Encoding
from airflow.utils.context import Context


Expand Down Expand Up @@ -442,6 +446,7 @@ def __init__(
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
):
if (
Expand Down Expand Up @@ -481,6 +486,7 @@ def __init__(
f"Expected one of {', '.join(map(repr, _SERIALIZERS))}"
)
raise AirflowException(msg)

self.pickling_library = _SERIALIZERS[serializer]
self.serializer: _SerializerTypeDef = serializer

Expand All @@ -494,6 +500,7 @@ def __init__(
)
self.env_vars = env_vars
self.inherit_env = inherit_env
self.use_airflow_context = use_airflow_context

@abstractmethod
def _iter_serializable_context_keys(self):
Expand Down Expand Up @@ -540,17 +547,23 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
string_args_path = tmp_dir / "string_args.txt"
script_path = tmp_dir / "script.py"
termination_log_path = tmp_dir / "termination.log"
airflow_context_path = tmp_dir / "airflow_context.json"

self._write_args(input_path)
self._write_string_args(string_args_path)

if self.use_airflow_context and (not is_pydantic_2_installed() or not _ENABLE_AIP_44):
error_msg = "`get_current_context()` needs to be used with Pydantic 2 and AIP-44 enabled."
raise AirflowException(error_msg)

jinja_context = {
"op_args": self.op_args,
"op_kwargs": op_kwargs,
"expect_airflow": self.expect_airflow,
"pickling_library": self.serializer,
"python_callable": self.python_callable.__name__,
"python_callable_source": self.get_python_source(),
"use_airflow_context": self.use_airflow_context,
}

if inspect.getfile(self.python_callable) == self.dag.fileloc:
Expand All @@ -561,6 +574,19 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
filename=os.fspath(script_path),
render_template_as_native_obj=self.dag.render_template_as_native_obj,
)
if self.use_airflow_context:
from airflow.serialization.serialized_objects import BaseSerialization

context = get_current_context()
with create_session() as session:
# FIXME: DetachedInstanceError
dag_run, task_instance = context["dag_run"], context["task_instance"]
session.add_all([dag_run, task_instance])
serializable_context: dict[Encoding, Any] = BaseSerialization.serialize(
context, use_pydantic_models=True
)
with airflow_context_path.open("w+") as file:
json.dump(serializable_context, file)

env_vars = dict(os.environ) if self.inherit_env else {}
if self.env_vars:
Expand All @@ -575,6 +601,7 @@ def _execute_python_callable_in_subprocess(self, python_path: Path):
os.fspath(output_path),
os.fspath(string_args_path),
os.fspath(termination_log_path),
os.fspath(airflow_context_path),
],
env=env_vars,
)
Expand Down Expand Up @@ -666,6 +693,7 @@ class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""

template_fields: Sequence[str] = tuple(
Expand Down Expand Up @@ -694,6 +722,7 @@ def __init__(
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
):
if (
Expand All @@ -715,6 +744,9 @@ def __init__(
)
if not is_venv_installed():
raise AirflowException("PythonVirtualenvOperator requires virtualenv, please install it.")
if use_airflow_context and (not expect_airflow and not system_site_packages):
error_msg = "use_airflow_context is set to True, but expect_airflow and system_site_packages are set to False."
raise AirflowException(error_msg)
if not requirements:
self.requirements: list[str] = []
elif isinstance(requirements, str):
Expand Down Expand Up @@ -744,6 +776,7 @@ def __init__(
env_vars=env_vars,
inherit_env=inherit_env,
use_dill=use_dill,
use_airflow_context=use_airflow_context,
**kwargs,
)

Expand Down Expand Up @@ -962,6 +995,7 @@ class ExternalPythonOperator(_BasePythonVirtualenvOperator):
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param use_airflow_context: Whether to provide ``get_current_context()`` to the python_callable.
"""

template_fields: Sequence[str] = tuple({"python"}.union(PythonOperator.template_fields))
Expand All @@ -983,10 +1017,14 @@ def __init__(
env_vars: dict[str, str] | None = None,
inherit_env: bool = True,
use_dill: bool = False,
use_airflow_context: bool = False,
**kwargs,
):
if not python:
raise ValueError("Python Path must be defined in ExternalPythonOperator")
if use_airflow_context and not expect_airflow:
error_msg = "use_airflow_context is set to True, but expect_airflow is set to False."
raise AirflowException(error_msg)
self.python = python
self.expect_pendulum = expect_pendulum
super().__init__(
Expand All @@ -1002,6 +1040,7 @@ def __init__(
env_vars=env_vars,
inherit_env=inherit_env,
use_dill=use_dill,
use_airflow_context=use_airflow_context,
**kwargs,
)

Expand Down
Loading