Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion airflow-core/src/airflow/dag_processing/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ErrorResponse,
GetConnection,
GetVariable,
PutVariable,
VariableResult,
)
from airflow.sdk.execution_time.supervisor import WatchedSubprocess
Expand All @@ -53,7 +54,7 @@
from airflow.typing_compat import Self

ToManager = Annotated[
Union["DagFileParsingResult", GetConnection, GetVariable],
Union["DagFileParsingResult", GetConnection, GetVariable, PutVariable],
Field(discriminator="type"),
]

Expand Down Expand Up @@ -290,6 +291,8 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None: #
dump_opts = {"exclude_unset": True}
else:
resp = var
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
else:
log.error("Unhandled request", msg=msg)
return
Expand Down
24 changes: 24 additions & 0 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,30 @@ def set(
"""
# check if the secret exists in the custom secrets' backend.
Variable.check_for_write_conflict(key=key)

# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
# back-compat layer

# If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps)
# and should use the Task SDK API server path
if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
warnings.warn(
"Using Variable.set from `airflow.models` is deprecated. Please use `from airflow.sdk import"
"Variable` instead",
DeprecationWarning,
stacklevel=1,
)
from airflow.sdk import Variable as TaskSDKVariable

TaskSDKVariable.set(
key=key,
value=value,
description=description,
serialize_json=serialize_json,
)
return

if serialize_json:
stored_value = json.dumps(value, indent=2)
else:
Expand Down
30 changes: 30 additions & 0 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,36 @@ def dag_in_a_fn():
if result.import_errors:
assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values()))

def test_top_level_variable_set(self, tmp_path: pathlib.Path):
    """A top-level ``Variable.set`` in a DAG file is stored and can be read back during the same parse."""
    from airflow.models.variable import Variable as VariableORM

    fake_log_handle = MagicMock()

    # Keep this function's name and body untouched: it is handed to
    # ``write_dag_in_a_fn_to_file`` below to produce the DAG file under test.
    def dag_in_a_fn():
        from airflow.sdk import DAG, Variable

        Variable.set(key="mykey", value="myvalue")
        with DAG(f"test_{Variable.get('mykey')}"):
            ...

    dag_file = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
    processor = DagFileProcessorProcess.start(
        id=1,
        path=dag_file,
        bundle_path=tmp_path,
        callbacks=[],
        logger_filehandle=fake_log_handle,
    )

    # Keep servicing the subprocess until the processor reports it is ready.
    while not processor.is_ready:
        processor._service_subprocess(0.1)

    with create_session() as session:
        parsed = processor.parsing_result
        assert parsed is not None
        assert parsed.import_errors == {}
        # The dag_id embeds the value read back via Variable.get inside the DAG file.
        assert parsed.serialized_dags[0].dag_id == "test_myvalue"

        # Exactly one variable row must exist, and it must be the one the DAG file set.
        stored_keys = [row.key for row in session.query(VariableORM).all()]
        assert stored_keys == ["mykey"]

def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
logger_filehandle = MagicMock()

Expand Down
10 changes: 10 additions & 0 deletions task-sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,13 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False):
if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET:
return default
raise

@classmethod
def set(cls, key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
    """
    Store a variable by delegating to ``_set_variable``.

    An ``AirflowRuntimeError`` raised by the delegate is logged and not re-raised.

    :param key: Name of the variable.
    :param value: Value to store.
    :param description: Optional description of the variable.
    :param serialize_json: Forwarded to ``_set_variable`` to request JSON encoding of ``value``.
    """
    from airflow.sdk.exceptions import AirflowRuntimeError
    from airflow.sdk.execution_time.context import _set_variable

    try:
        outcome = _set_variable(key, value, description, serialize_json=serialize_json)
    except AirflowRuntimeError as err:
        log.exception(err)
        return None
    return outcome
46 changes: 46 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,52 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
return variable.value


def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
    """
    Send a ``PutVariable`` request for ``key`` over ``SUPERVISOR_COMMS``.

    Before sending, every loaded secrets backend is probed for ``key``. A hit only
    produces a warning and a failing backend only produces an error log; neither
    stops the write.

    :param key: Name of the variable.
    :param value: Value to store; JSON-encoded first when ``serialize_json`` is true.
    :param description: Optional description sent along with the request.
    :param serialize_json: Whether to ``json.dumps`` the value before sending.
    """
    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
    #   or `airflow.sdk.execution_time.variable`
    #   A reason to not move it to `airflow.sdk.execution_time.comms` is that it
    #   will make that module depend on Task SDK, which is not ideal because we intend to
    #   keep Task SDK as a separate package than execution time mods.
    import json

    from airflow.sdk.execution_time.comms import PutVariable
    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

    # Check for write conflicts on the worker: warn when a secrets backend already defines the key.
    for backend in ensure_secrets_backend_loaded():
        try:
            if backend.get_variable(key=key) is None:
                continue
            backend_name = type(backend).__name__
            log.warning(
                "The variable %s is defined in the %s secrets backend, which takes "
                "precedence over reading from the database. The value in the database will be "
                "updated, but to read it you have to delete the conflicting variable "
                "from %s",
                key,
                backend_name,
                backend_name,
            )
        except Exception:
            # A failing backend must not block the write; log it and try the next one.
            log.exception(
                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
                type(backend).__name__,
            )

    if serialize_json:
        try:
            value = json.dumps(value, indent=2)
        except Exception as err:
            # NOTE(review): a serialization failure is only logged, so the original
            # (un-serialized) value is what gets sent below. Confirm this is intended.
            log.exception(err)

    # It is best to take the lock everywhere or nowhere on SUPERVISOR_COMMS. The lock was
    # primarily added for triggers, but it doesn't make sense to have it in some places
    # and not in the rest. A lot of this will be simplified by
    # https://github.com/apache/airflow/issues/46426
    with SUPERVISOR_COMMS.lock:
        request = PutVariable(key=key, value=value, description=description)
        SUPERVISOR_COMMS.send_request(log=log, msg=request)


class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""

Expand Down
36 changes: 35 additions & 1 deletion task-sdk/tests/task_sdk/definitions/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
# under the License.
from __future__ import annotations

import json
from unittest import mock

import pytest

from airflow.configuration import initialize_secrets_backends
from airflow.sdk import Variable
from airflow.sdk.execution_time.comms import VariableResult
from airflow.sdk.execution_time.comms import PutVariable, VariableResult
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS

from tests_common.test_utils.config import conf_vars
Expand Down Expand Up @@ -55,6 +56,39 @@ def test_var_get(self, deserialize_json, value, expected_value, mock_supervisor_
assert var is not None
assert var == expected_value

@pytest.mark.parametrize(
    "key, value, description, serialize_json",
    [
        pytest.param(
            "key",
            "value",
            "description",
            False,
            id="simple-value",
        ),
        pytest.param(
            "key2",
            {"hi": "there", "hello": 42, "flag": True},
            "description2",
            True,
            id="serialize-json-value",
        ),
    ],
)
def test_var_set(self, key, value, description, serialize_json, mock_supervisor_comms):
    """``Variable.set`` sends exactly one ``PutVariable`` request carrying the (optionally JSON-encoded) value."""
    Variable.set(key=key, value=value, description=description, serialize_json=serialize_json)

    # With serialize_json the value is JSON-encoded (indent=2) before it is sent.
    expected_value = json.dumps(value, indent=2) if serialize_json else value

    # The request is built from key, value and description only; ``serialize_json``
    # is consumed before sending, so it must not appear in the expected message.
    mock_supervisor_comms.send_request.assert_called_once_with(
        log=mock.ANY,
        msg=PutVariable(key=key, value=expected_value, description=description),
    )


class TestVariableFromSecrets:
def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path):
Expand Down