Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,19 @@ def _parse_file_entrypoint():

import structlog

from airflow.sdk.execution_time import comms, task_runner
from airflow.sdk.execution_time.comms import CommsDecoder
from airflow.sdk.execution_time.task_runner import SupervisorComms

# Parse DAG file, send JSON back up!
comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager](
comms_decoder = CommsDecoder[ToDagProcessor, ToManager](
body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor),
)

msg = comms_decoder._get_response()
if not isinstance(msg, DagFileParseRequest):
raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}")

task_runner.SUPERVISOR_COMMS = comms_decoder
SupervisorComms().set_comms(comms_decoder)
log = structlog.get_logger(logger_name="task")

result = _parse_file(msg, log)
Expand Down
6 changes: 3 additions & 3 deletions airflow-core/src/airflow/jobs/triggerer_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,10 +930,10 @@ async def init_comms(self):
"""
Set up the communications pipe between this process and the supervisor.

This also sets up the SUPERVISOR_COMMS so that TaskSDK code can work as expected too (but that will
This also sets up the supervisor-comms so that TaskSDK code can work as expected too (but that will
need to be wrapped in an ``sync_to_async()`` call)
"""
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.task_runner import SupervisorComms

# Yes, we read and write to stdin! It's a socket, not a normal stdin.
reader, writer = await asyncio.open_connection(sock=socket(fileno=0))
Expand All @@ -943,7 +943,7 @@ async def init_comms(self):
async_reader=reader,
)

task_runner.SUPERVISOR_COMMS = self.comms_decoder
SupervisorComms().set_comms(self.comms_decoder)

msg = await self.comms_decoder._aget_response(expect_id=0)

Expand Down
7 changes: 7 additions & 0 deletions airflow-core/src/airflow/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any

from sqlalchemy import Integer, MetaData, String, text
Expand Down Expand Up @@ -98,3 +99,9 @@ class TaskInstanceDependencies(Base):
dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1"))


def is_client_process_context() -> bool:
"""Check if we are in an execution context (Task, Dag Parser or Triggerer perhaps)."""
process_context = os.environ.get("_AIRFLOW_PROCESS_CONTEXT", "").lower()
return process_context == "client"
13 changes: 3 additions & 10 deletions airflow-core/src/airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import json
import logging
import re
import sys
import warnings
from contextlib import suppress
from json import JSONDecodeError
Expand All @@ -34,7 +33,7 @@
from airflow._shared.secrets_masker import mask_secret
from airflow.configuration import conf, ensure_secrets_loaded
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.base import ID_LEN, Base, is_client_process_context
from airflow.models.crypto import get_fernet
from airflow.utils.helpers import prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -500,13 +499,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None)
:param team_name: Team name associated to the task trying to access the connection (if any)
:return: connection
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
from airflow.sdk import Connection as TaskSDKConnection
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType

Expand Down Expand Up @@ -590,7 +583,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s

@classmethod
def from_json(cls, value, conn_id=None) -> Connection:
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
from airflow.sdk import Connection as TaskSDKConnection

warnings.warn(
Expand Down
35 changes: 5 additions & 30 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import contextlib
import json
import logging
import sys
import warnings
from typing import TYPE_CHECKING, Any

Expand All @@ -30,7 +29,7 @@

from airflow._shared.secrets_masker import mask_secret
from airflow.configuration import conf, ensure_secrets_loaded
from airflow.models.base import ID_LEN, Base
from airflow.models.base import ID_LEN, Base, is_client_process_context
from airflow.models.crypto import get_fernet
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -148,13 +147,7 @@ def get(
:param deserialize_json: Deserialize the value to a Python dict
:param team_name: Team name associated to the task trying to access the variable (if any)
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
warnings.warn(
"Using Variable.get from `airflow.models` is deprecated."
"Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down Expand Up @@ -208,13 +201,7 @@ def set(
:param team_name: Team name associated to the variable (if any)
:param session: optional session, use if provided or create a new one
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
warnings.warn(
"Using Variable.set from `airflow.models` is deprecated."
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down Expand Up @@ -339,13 +326,7 @@ def update(
:param team_name: Team name associated to the variable (if any)
:param session: optional session, use if provided or create a new one
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
warnings.warn(
"Using Variable.update from `airflow.models` is deprecated."
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.",
Expand Down Expand Up @@ -405,13 +386,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non
:param team_name: Team name associated to the task trying to delete the variable (if any)
:param session: optional session, use if provided or create a new one
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
if is_client_process_context():
warnings.warn(
"Using Variable.delete from `airflow.models` is deprecated."
"Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead",
Expand Down
2 changes: 1 addition & 1 deletion airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def fn(moment): ...
assert "got an unexpected keyword argument 'not_exists_arg'" in str(err)

@pytest.mark.asyncio
@patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True)
@patch("airflow.sdk.execution_time.task_runner.SupervisorComms._comms", create=True)
async def test_invalid_trigger(self, supervisor_builder):
"""Test the behaviour when we try to run an invalid Trigger"""
workload = workloads.RunTrigger.model_construct(
Expand Down
6 changes: 3 additions & 3 deletions airflow-core/tests/unit/models/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import os
import re
import sys
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -353,6 +354,7 @@ def test_extra_dejson(self):
}

@mock.patch("airflow.sdk.Connection.get")
@mock.patch.dict(os.environ, {"_AIRFLOW_PROCESS_CONTEXT": "client"})
def test_get_connection_from_secrets_task_sdk_success(self, mock_get):
"""Test the get_connection_from_secrets method with Task SDK success path."""
from airflow.sdk import Connection as SDKConnection
Expand All @@ -361,7 +363,6 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get):
mock_get.return_value = expected_connection

mock_task_runner = mock.MagicMock()
mock_task_runner.SUPERVISOR_COMMS = True

with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": mock_task_runner}):
result = Connection.get_connection_from_secrets("test_conn")
Expand All @@ -370,10 +371,10 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get):
assert result.conn_type == "test_type"

@mock.patch("airflow.sdk.Connection")
@mock.patch.dict(os.environ, {"_AIRFLOW_PROCESS_CONTEXT": "client"})
def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_connection):
"""Test the get_connection_from_secrets method with Task SDK not found path."""
mock_task_runner = mock.MagicMock()
mock_task_runner.SUPERVISOR_COMMS = True

mock_task_sdk_connection.get.side_effect = AirflowRuntimeError(
error=ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)
Expand All @@ -383,7 +384,6 @@ def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_conn
with pytest.raises(AirflowNotFoundException):
Connection.get_connection_from_secrets("test_conn")

@mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": None})
@mock.patch("airflow.sdk.Connection")
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
@mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection")
Expand Down
35 changes: 32 additions & 3 deletions devel-common/src/tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2227,7 +2227,7 @@ def override_caplog(request):
@pytest.fixture
def mock_supervisor_comms(monkeypatch):
# for back-compat
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS

if not AIRFLOW_V_3_0_PLUS:
yield None
Expand All @@ -2239,13 +2239,42 @@ def mock_supervisor_comms(monkeypatch):
# core and TaskSDK is finished
if CommsDecoder := getattr(comms, "CommsDecoder", None):
comms = mock.create_autospec(CommsDecoder)
monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False)
else:
CommsDecoder = getattr(task_runner, "CommsDecoder")
comms = mock.create_autospec(CommsDecoder)
comms.send = comms.get_message

if AIRFLOW_V_3_2_PLUS:
svcomms = task_runner.SupervisorComms()
old = svcomms.get_comms()
svcomms.set_comms(comms)
yield comms
svcomms.set_comms(old)
else:
monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False)
yield comms
yield comms


@pytest.fixture
def mock_unset_supervisor_comms(monkeypatch):
# for back-compat
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS

if not AIRFLOW_V_3_0_PLUS:
yield None
return

from airflow.sdk.execution_time import comms, task_runner

if AIRFLOW_V_3_2_PLUS:
svcomms = task_runner.SupervisorComms()
old = svcomms.get_comms()
svcomms.reset_comms()
yield comms
svcomms.set_comms(old)
else:
monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", None, raising=False)
yield comms


@pytest.fixture
Expand Down
Loading
Loading