diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index b6e07242dce30..365d771a081d9 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -38,6 +38,7 @@
     ErrorResponse,
     GetConnection,
     GetVariable,
+    PutVariable,
     VariableResult,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
@@ -53,7 +54,7 @@
     from airflow.typing_compat import Self
 
 ToManager = Annotated[
-    Union["DagFileParsingResult", GetConnection, GetVariable],
+    Union["DagFileParsingResult", GetConnection, GetVariable, PutVariable],
     Field(discriminator="type"),
 ]
 
@@ -290,6 +291,8 @@ def _handle_request(self, msg: ToManager, log: FilteringBoundLogger) -> None:
                 # dump_opts = {"exclude_unset": True}
             else:
                 resp = var
+        elif isinstance(msg, PutVariable):
+            self.client.variables.set(msg.key, msg.value, msg.description)
         else:
             log.error("Unhandled request", msg=msg)
             return
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index e0b9168628566..9c6af8ab7275a 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -201,6 +201,30 @@ def set(
         """
         # check if the secret exists in the custom secrets' backend.
         Variable.check_for_write_conflict(key=key)
+
+        # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
+        # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
+        # back-compat layer
+
+        # If this is set, it means we are in some kind of execution context (Task, Dag Parse or Triggerer
+        # perhaps) and should use the Task SDK API server path
+        if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"):
+            warnings.warn(
+                "Using Variable.set from `airflow.models` is deprecated. "
+                "Please use `from airflow.sdk import Variable` instead",
+                DeprecationWarning,
+                stacklevel=1,
+            )
+            from airflow.sdk import Variable as TaskSDKVariable
+
+            TaskSDKVariable.set(
+                key=key,
+                value=value,
+                description=description,
+                serialize_json=serialize_json,
+            )
+            return
+
         if serialize_json:
             stored_value = json.dumps(value, indent=2)
         else:
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index ca9670f81ae4a..4135c617f39bb 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -181,6 +181,36 @@ def dag_in_a_fn():
         if result.import_errors:
             assert "VARIABLE_NOT_FOUND" in next(iter(result.import_errors.values()))
 
+    def test_top_level_variable_set(self, tmp_path: pathlib.Path):
+        from airflow.models.variable import Variable as VariableORM
+
+        logger_filehandle = MagicMock()
+
+        def dag_in_a_fn():
+            from airflow.sdk import DAG, Variable
+
+            Variable.set(key="mykey", value="myvalue")
+            with DAG(f"test_{Variable.get('mykey')}"):
+                ...
+
+        path = write_dag_in_a_fn_to_file(dag_in_a_fn, tmp_path)
+        proc = DagFileProcessorProcess.start(
+            id=1, path=path, bundle_path=tmp_path, callbacks=[], logger_filehandle=logger_filehandle
+        )
+
+        while not proc.is_ready:
+            proc._service_subprocess(0.1)
+
+        with create_session() as session:
+            result = proc.parsing_result
+            assert result is not None
+            assert result.import_errors == {}
+            assert result.serialized_dags[0].dag_id == "test_myvalue"
+
+            all_vars = session.query(VariableORM).all()
+            assert len(all_vars) == 1
+            assert all_vars[0].key == "mykey"
+
     def test_top_level_connection_access(self, tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch):
         logger_filehandle = MagicMock()
 
diff --git a/task-sdk/src/airflow/sdk/definitions/variable.py b/task-sdk/src/airflow/sdk/definitions/variable.py
index 87b0ee29fab50..9e80b8a566717 100644
--- a/task-sdk/src/airflow/sdk/definitions/variable.py
+++ b/task-sdk/src/airflow/sdk/definitions/variable.py
@@ -55,3 +55,13 @@ def get(cls, key: str, default: Any = NOTSET, deserialize_json: bool = False):
             if e.error.error == ErrorType.VARIABLE_NOT_FOUND and default is not NOTSET:
                 return default
             raise
+
+    @classmethod
+    def set(cls, key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
+        from airflow.sdk.exceptions import AirflowRuntimeError
+        from airflow.sdk.execution_time.context import _set_variable
+
+        try:
+            return _set_variable(key, value, description, serialize_json=serialize_json)
+        except AirflowRuntimeError as e:
+            log.exception(e)
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 2ff7dbbda0871..fee91f3efafc4 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -211,6 +211,52 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
     return variable.value
 
 
+def _set_variable(key: str, value: Any, description: str | None = None, serialize_json: bool = False) -> None:
+    # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
+    # or `airflow.sdk.execution_time.variable`
+    # A reason to not move it to `airflow.sdk.execution_time.comms` is that it
+    # will make that module depend on Task SDK, which is not ideal because we intend to
+    # keep the Task SDK as a separate package from execution-time mods.
+    import json
+
+    from airflow.sdk.execution_time.comms import PutVariable
+    from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
+    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+    # check for write conflicts on the worker
+    for secrets_backend in ensure_secrets_backend_loaded():
+        try:
+            var_val = secrets_backend.get_variable(key=key)
+            if var_val is not None:
+                _backend_name = type(secrets_backend).__name__
+                log.warning(
+                    "The variable %s is defined in the %s secrets backend, which takes "
+                    "precedence over reading from the database. The value in the database will be "
+                    "updated, but to read it you have to delete the conflicting variable "
+                    "from %s",
+                    key,
+                    _backend_name,
+                    _backend_name,
+                )
+        except Exception:
+            log.exception(
+                "Unable to retrieve variable from secrets backend (%s). Checking subsequent secrets backend.",
+                type(secrets_backend).__name__,
+            )
+
+    try:
+        if serialize_json:
+            value = json.dumps(value, indent=2)
+    except Exception as e:
+        log.exception(e)
+
+    # It is best to hold the lock everywhere or nowhere on SUPERVISOR_COMMS; the lock was
+    # primarily added for triggers, but it doesn't make sense to have it in some places
+    # and not in others. A lot of this will be simplified by https://github.com/apache/airflow/issues/46426
+    with SUPERVISOR_COMMS.lock:
+        SUPERVISOR_COMMS.send_request(log=log, msg=PutVariable(key=key, value=value, description=description))
+
+
 class ConnectionAccessor:
     """Wrapper to access Connection entries in template."""
 
diff --git a/task-sdk/tests/task_sdk/definitions/test_variables.py b/task-sdk/tests/task_sdk/definitions/test_variables.py
index 6560bdee90381..242c5af407bcd 100644
--- a/task-sdk/tests/task_sdk/definitions/test_variables.py
+++ b/task-sdk/tests/task_sdk/definitions/test_variables.py
@@ -17,13 +17,14 @@
 # under the License.
 from __future__ import annotations
 
+import json
 from unittest import mock
 
 import pytest
 
 from airflow.configuration import initialize_secrets_backends
 from airflow.sdk import Variable
-from airflow.sdk.execution_time.comms import VariableResult
+from airflow.sdk.execution_time.comms import PutVariable, VariableResult
 from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
 
 from tests_common.test_utils.config import conf_vars
@@ -55,6 +56,39 @@ def test_var_get(self, deserialize_json, value, expected_value, mock_supervisor_
         assert var is not None
         assert var == expected_value
 
+    @pytest.mark.parametrize(
+        "key, value, description, serialize_json",
+        [
+            pytest.param(
+                "key",
+                "value",
+                "description",
+                False,
+                id="simple-value",
+            ),
+            pytest.param(
+                "key2",
+                {"hi": "there", "hello": 42, "flag": True},
+                "description2",
+                True,
+                id="serialize-json-value",
+            ),
+        ],
+    )
+    def test_var_set(self, key, value, description, serialize_json, mock_supervisor_comms):
+        Variable.set(key=key, value=value, description=description, serialize_json=serialize_json)
+
+        expected_value = value
+        if serialize_json:
+            expected_value = json.dumps(value, indent=2)
+
+        mock_supervisor_comms.send_request.assert_called_once_with(
+            log=mock.ANY,
+            msg=PutVariable(
+                key=key, value=expected_value, description=description, serialize_json=serialize_json
+            ),
+        )
+
 
 class TestVariableFromSecrets:
     def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path):
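
Usage sketch (illustration only, not part of the patch): with this change, `Variable.set` is available on the Task SDK `Variable` class, and a call made from top-level DAG-file code (or via the legacy `airflow.models.Variable.set`, which now delegates when `SUPERVISOR_COMMS` is present) is sent over supervisor comms as a `PutVariable` message and persisted through `client.variables.set(...)` instead of a direct database session. The snippet mirrors the new `test_top_level_variable_set` test; the key, value, and DAG id are hypothetical.

    from airflow.sdk import DAG, Variable

    # "deploy_env" is a hypothetical key used only for illustration.
    # The write travels SUPERVISOR_COMMS -> PutVariable -> client.variables.set(...).
    Variable.set(key="deploy_env", value="staging", description="target tier")

    # Reads still use the existing GetVariable path, so the value written above is
    # visible within the same parse, as exercised by test_top_level_variable_set.
    with DAG(f"example_dag_{Variable.get('deploy_env')}"):
        ...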