Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def create_app(apps: str = "all") -> FastAPI:
init_middlewares(app)

if "execution" in apps_list or "all" in apps_list:
task_exec_api_app = create_task_execution_api_app(app)
task_exec_api_app = create_task_execution_api_app()
init_error_handlers(task_exec_api_app)
app.mount("/execution", task_exec_api_app)

Expand Down
42 changes: 41 additions & 1 deletion airflow/api_fastapi/execution_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,16 @@
from __future__ import annotations

from contextlib import asynccontextmanager
from functools import cached_property
from typing import TYPE_CHECKING

import attrs
from fastapi import FastAPI
from fastapi.openapi.utils import get_openapi

if TYPE_CHECKING:
import httpx


@asynccontextmanager
async def lifespan(app: FastAPI):
Expand All @@ -30,7 +36,7 @@ async def lifespan(app: FastAPI):
yield


def create_task_execution_api_app(app: FastAPI) -> FastAPI:
def create_task_execution_api_app() -> FastAPI:
"""Create FastAPI app for task execution API."""
from airflow.api_fastapi.execution_api.routes import execution_api_router

Expand Down Expand Up @@ -88,3 +94,37 @@ def get_extra_schemas() -> dict[str, dict]:
# as that has different payload requirements
"TerminalTIState": {"type": "string", "enum": list(TerminalTIState)},
}


@attrs.define()
class InProcessExecuctionAPI:
"""
A helper class to make it possible to run the ExecutionAPI "in-process".

The sync version of this makes use of a2wsgi which runs the async loop in a separate thread. This is
needed so that we can use the sync httpx client
"""

_app: FastAPI | None = None

@cached_property
def app(self):
if not self._app:
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app

self._app = create_task_execution_api_app()

return self._app

@cached_property
def transport(self) -> httpx.WSGITransport:
import httpx
from a2wsgi import ASGIMiddleware

return httpx.WSGITransport(app=ASGIMiddleware(self.app)) # type: ignore[arg-type]

@cached_property
def atransport(self) -> httpx.ASGITransport:
import httpx

return httpx.ASGITransport(app=self.app)
64 changes: 54 additions & 10 deletions airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import functools
import os
import sys
import traceback
Expand All @@ -32,17 +33,29 @@
)
from airflow.configuration import conf
from airflow.models.dagbag import DagBag
from airflow.sdk.execution_time.comms import GetConnection, GetVariable
from airflow.sdk.execution_time.comms import ConnectionResult, GetConnection, GetVariable, VariableResult
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
from airflow.serialization.serialized_objects import LazyDeserializedDAG, SerializedDAG
from airflow.stats import Stats

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger

from airflow.api_fastapi.execution_api.app import InProcessExecuctionAPI
from airflow.sdk.api.client import Client
from airflow.sdk.definitions.context import Context
from airflow.typing_compat import Self

ToManager = Annotated[
Union["DagFileParsingResult", GetConnection, GetVariable],
Field(discriminator="type"),
]

ToDagProcessor = Annotated[
Union["DagFileParseRequest", ConnectionResult, VariableResult],
Field(discriminator="type"),
]


def _parse_file_entrypoint():
import os
Expand All @@ -51,19 +64,24 @@ def _parse_file_entrypoint():

from airflow.sdk.execution_time import task_runner
from airflow.settings import configure_orm

# Parse DAG file, send JSON back up!

# We need to reconfigure the orm here, as DagFileProcessorManager does db queries for bundles, and
# the session across forks blows things up.
configure_orm()

comms_decoder = task_runner.CommsDecoder[DagFileParseRequest, DagFileParsingResult](
comms_decoder = task_runner.CommsDecoder[ToDagProcessor, ToManager](
input=sys.stdin,
decoder=TypeAdapter[DagFileParseRequest](DagFileParseRequest),
decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
)

msg = comms_decoder.get_message()
if not isinstance(msg, DagFileParseRequest):
raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}")
comms_decoder.request_socket = os.fdopen(msg.requests_fd, "wb", buffering=0)

task_runner.SUPERVISOR_COMMS = comms_decoder
log = structlog.get_logger(logger_name="task")

result = _parse_file(msg, log)
Expand Down Expand Up @@ -188,10 +206,12 @@ class DagFileParsingResult(BaseModel):
type: Literal["DagFileParsingResult"] = "DagFileParsingResult"


ToParent = Annotated[
Union[DagFileParsingResult, GetConnection, GetVariable],
Field(discriminator="type"),
]
@functools.cache
def in_process_api_server() -> InProcessExecuctionAPI:
from airflow.api_fastapi.execution_api.app import InProcessExecuctionAPI

api = InProcessExecuctionAPI()
return api


@attrs.define(kw_only=True)
Expand All @@ -207,7 +227,7 @@ class DagFileProcessorProcess(WatchedSubprocess):
"""

parsing_result: DagFileParsingResult | None = None
decoder: ClassVar[TypeAdapter[ToParent]] = TypeAdapter[ToParent](ToParent)
decoder: ClassVar[TypeAdapter[ToManager]] = TypeAdapter[ToManager](ToManager)

@classmethod
def start( # type: ignore[override]
Expand Down Expand Up @@ -237,12 +257,36 @@ def _on_child_started(
)
self.stdin.write(msg.model_dump_json().encode() + b"\n")

def _handle_request(self, msg: ToParent, log: FilteringBoundLogger) -> None: # type: ignore[override]
# TODO: GetVariable etc -- parsing a dag can run top level code that asks for an Airflow Variable
@functools.cached_property
def client(self) -> Client:
from airflow.sdk.api.client import Client

client = Client(base_url=None, token="", dry_run=True, transport=in_process_api_server().transport)
# Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str`
client.base_url = "http://in-process.invalid./" # type: ignore[assignment]
return client

def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: # type: ignore[override]
from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse

resp = None
if isinstance(msg, DagFileParsingResult):
self.parsing_result = msg
return
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
resp = conn_result.model_dump_json(exclude_unset=True).encode()
else:
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
if isinstance(var, VariableResponse):
var_result = VariableResult.from_variable_response(var)
resp = var_result.model_dump_json(exclude_unset=True).encode()
else:
resp = var.model_dump_json().encode()
else:
log.error("Unhandled request", msg=msg)
return
Expand Down
24 changes: 24 additions & 0 deletions airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import json
import logging
import sys
import warnings
from typing import TYPE_CHECKING, Any

from sqlalchemy import Boolean, Column, Integer, String, Text, delete, select
Expand Down Expand Up @@ -136,6 +138,28 @@ def get(
:param default_var: Default value of the Variable if the Variable doesn't exist
:param deserialize_json: Deserialize the value to a Python dict
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
warnings.warn(
"Using Variable.get from `airflow.models` is deprecated. Please use `from airflow.sdk import"
"Variable` instead",
DeprecationWarning,
stacklevel=1,
)
from airflow.sdk import Variable as TaskSDKVariable
from airflow.sdk.definitions._internal.types import NOTSET

return TaskSDKVariable.get(
key,
default=NOTSET if default_var is cls.__NO_DEFAULT_SENTINEL else default_var,
deserialize_json=deserialize_json,
)

var_val = Variable.get_variable_from_secrets(key=key)
if var_val is None:
if default_var is not cls.__NO_DEFAULT_SENTINEL:
Expand Down
1 change: 1 addition & 0 deletions hatch_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@
}

DEPENDENCIES = [
"a2wsgi>=1.10.8",
# Alembic is important to handle our migrations in predictable and performant way. It is developed
# together with SQLAlchemy. Our experience with Alembic is that it very stable in minor version
# The 1.13.0 of alembic marked some migration code as SQLAlchemy 2+ only so we limit it to 1.13.1
Expand Down
2 changes: 2 additions & 0 deletions task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"Label",
"MappedOperator",
"TaskGroup",
"Variable",
"XComArg",
"dag",
"get_current_context",
Expand All @@ -46,6 +47,7 @@
from airflow.sdk.definitions.edges import EdgeModifier, Label
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.definitions.xcom_arg import XComArg

__lazy_imports: dict[str, str] = {
Expand Down
4 changes: 2 additions & 2 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
if dry_run:
# If dry run is requested, install a no op handler so that simple tasks can "heartbeat" using a
# real client, but just don't make any HTTP requests
kwargs["transport"] = httpx.MockTransport(noop_handler)
kwargs["base_url"] = "dry-run://server"
kwargs.setdefault("transport", httpx.MockTransport(noop_handler))
kwargs.setdefault("base_url", "dry-run://server")
else:
kwargs["base_url"] = base_url
pyver = f"{'.'.join(map(str, sys.version_info[:3]))}"
Expand Down
13 changes: 13 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import attrs

from airflow.sdk.definitions._internal.types import NOTSET


@attrs.define
class Variable:
Expand All @@ -39,3 +41,14 @@ class Variable:
description: str | None = None

# TODO: Extend this definition for reading/writing variables without context
@classmethod
def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False):
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.context import _get_variable

try:
return _get_variable(key, deserialize_json=deserialize_json).value
except AirflowRuntimeError as e:
if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET:
return default
raise
6 changes: 3 additions & 3 deletions tests/api_fastapi/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def test_core_api_app(
def test_execution_api_app(
mock_create_task_exec_api, mock_init_plugins, mock_init_views, mock_init_dag_bag, client
):
test_app = client(apps="execution").app
client(apps="execution")

# Assert that execution-related functions were called
mock_create_task_exec_api.assert_called_once_with(test_app)
mock_create_task_exec_api.assert_called_once()

# Assert that core-related functions were NOT called
mock_init_dag_bag.assert_not_called()
Expand Down Expand Up @@ -91,4 +91,4 @@ def test_all_apps(mock_create_task_exec_api, mock_init_plugins, mock_init_views,
mock_init_plugins.assert_called_once_with(test_app)

# Assert that execution-related functions were also called
mock_create_task_exec_api.assert_called_once_with(test_app)
mock_create_task_exec_api.assert_called_once_with()
Loading
Loading