diff --git a/task_sdk/src/airflow/sdk/__init__.py b/task_sdk/src/airflow/sdk/__init__.py index 13f2819a1c6e6..1117e946f8ea6 100644 --- a/task_sdk/src/airflow/sdk/__init__.py +++ b/task_sdk/src/airflow/sdk/__init__.py @@ -46,6 +46,7 @@ "EdgeModifier": ".definitions.edges", "Label": ".definitions.edges", "Connection": ".definitions.connection", + "Variable": ".definitions.variable", } diff --git a/task_sdk/src/airflow/sdk/api/client.py b/task_sdk/src/airflow/sdk/api/client.py index 4c6a02efe9451..0ef3efe95f14e 100644 --- a/task_sdk/src/airflow/sdk/api/client.py +++ b/task_sdk/src/airflow/sdk/api/client.py @@ -189,9 +189,20 @@ class VariableOperations: def __init__(self, client: Client): self.client = client - def get(self, key: str) -> VariableResponse: + def get(self, key: str) -> VariableResponse | ErrorResponse: """Get a variable from the API server.""" - resp = self.client.get(f"variables/{key}") + try: + resp = self.client.get(f"variables/{key}") + except ServerResponseError as e: + if e.response.status_code == HTTPStatus.NOT_FOUND: + log.error( + "Variable not found", + key=key, + detail=e.detail, + status_code=e.response.status_code, + ) + return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key}) + raise return VariableResponse.model_validate_json(resp.read()) def set(self, key: str, value: str | None, description: str | None = None): diff --git a/task_sdk/src/airflow/sdk/definitions/variable.py b/task_sdk/src/airflow/sdk/definitions/variable.py new file mode 100644 index 0000000000000..5f458580065c5 --- /dev/null +++ b/task_sdk/src/airflow/sdk/definitions/variable.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +import attrs + + +@attrs.define +class Variable: + """ + A generic way to store and retrieve arbitrary content or settings as a simple key/value store. + + :param key: The variable key. + :param value: The variable value. + :param description: The variable description. + + """ + + key: str + # keeping as any for supporting deserialize_json + value: Any | None = None + description: str | None = None + + # TODO: Extend this definition for reading/writing variables without context diff --git a/task_sdk/src/airflow/sdk/execution_time/comms.py b/task_sdk/src/airflow/sdk/execution_time/comms.py index b90787ca4cfc9..e1007d1f2fc05 100644 --- a/task_sdk/src/airflow/sdk/execution_time/comms.py +++ b/task_sdk/src/airflow/sdk/execution_time/comms.py @@ -121,7 +121,7 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable VariableResponse is autogenerated from the API schema, so we need to convert it to VariableResult for communication between the Supervisor and the task process. """ - return cls(**variable_response.model_dump()) + return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult") class ErrorResponse(BaseModel): diff --git a/task_sdk/src/airflow/sdk/execution_time/context.py b/task_sdk/src/airflow/sdk/execution_time/context.py index 30295a84f9d29..d96f5aeda2c95 100644 --- a/task_sdk/src/airflow/sdk/execution_time/context.py +++ b/task_sdk/src/airflow/sdk/execution_time/context.py @@ -21,19 +21,31 @@ import structlog from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.types import NOTSET if TYPE_CHECKING: from airflow.sdk.definitions.connection import Connection - from airflow.sdk.execution_time.comms import ConnectionResult + from airflow.sdk.definitions.variable import Variable + from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult -def _convert_connection_result_conn(conn_result: ConnectionResult): +def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection: from airflow.sdk.definitions.connection import Connection # `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True)) +def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable: + from airflow.sdk.definitions.variable import Variable + + if deserialize_json: + import json + + var_result.value = json.loads(var_result.value) # type: ignore + return Variable(**var_result.model_dump(exclude={"type"})) + + def _get_connection(conn_id: str) -> Connection: # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms` # or `airflow.sdk.execution_time.connection` @@ -54,6 +66,26 @@ def _get_connection(conn_id: str) -> Connection: return _convert_connection_result_conn(msg) +def _get_variable(key: str, deserialize_json: bool) -> Variable: + # TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms` + # or `airflow.sdk.execution_time.variable` + # A reason to not move it to `airflow.sdk.execution_time.comms` is that it + # will make that module depend on Task SDK, which is not ideal because we intend to + # keep Task SDK as a separate package than execution time mods. + from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + log = structlog.get_logger(logger_name="task") + SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key)) + msg = SUPERVISOR_COMMS.get_message() + if isinstance(msg, ErrorResponse): + raise AirflowRuntimeError(msg) + + if TYPE_CHECKING: + assert isinstance(msg, VariableResult) + return _convert_variable_result_to_variable(msg, deserialize_json) + + class ConnectionAccessor: """Wrapper to access Connection entries in template.""" @@ -76,3 +108,30 @@ def get(self, conn_id: str, default_conn: Any = None) -> Any: if e.error.error == ErrorType.CONNECTION_NOT_FOUND: return default_conn raise + + +class VariableAccessor: + """Wrapper to access Variable values in template.""" + + def __init__(self, deserialize_json: bool) -> None: + self._deserialize_json = deserialize_json + + def __eq__(self, other): + if not isinstance(other, VariableAccessor): + return False + # All instances of VariableAccessor are equal since it is a stateless dynamic accessor + return True + + def __repr__(self) -> str: + return "" + + def __getattr__(self, key: str) -> Any: + return _get_variable(key, self._deserialize_json) + + def get(self, key, default_var: Any = NOTSET) -> Any: + try: + return _get_variable(key, self._deserialize_json) + except AirflowRuntimeError as e: + if e.error.error == ErrorType.VARIABLE_NOT_FOUND: + return default_var + raise diff --git a/task_sdk/src/airflow/sdk/execution_time/supervisor.py b/task_sdk/src/airflow/sdk/execution_time/supervisor.py index 811d1ce86a60d..2363e927b468f 100644 --- a/task_sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task_sdk/src/airflow/sdk/execution_time/supervisor.py @@ -58,6 +58,7 @@ IntermediateTIState, TaskInstance, TerminalTIState, + VariableResponse, ) from airflow.sdk.execution_time.comms import ( ConnectionResult, @@ -722,8 +723,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger): resp = conn.model_dump_json().encode() elif isinstance(msg, GetVariable): var = self.client.variables.get(msg.key) - var_result = VariableResult.from_variable_response(var) - resp = var_result.model_dump_json().encode() + if isinstance(var, VariableResponse): + var_result = VariableResult.from_variable_response(var) + resp = var_result.model_dump_json(exclude_unset=True).encode() + else: + resp = var.model_dump_json().encode() elif isinstance(msg, GetXCom): xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index) xcom_result = XComResult.from_xcom_response(xcom) diff --git a/task_sdk/src/airflow/sdk/execution_time/task_runner.py b/task_sdk/src/airflow/sdk/execution_time/task_runner.py index effea9f6a7e52..4f8a4f0045c00 100644 --- a/task_sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task_sdk/src/airflow/sdk/execution_time/task_runner.py @@ -44,7 +44,7 @@ ToTask, XComResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor +from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor if TYPE_CHECKING: from structlog.typing import FilteringBoundLogger as Logger @@ -85,10 +85,10 @@ def get_template_context(self): # "prev_end_date_success": get_prev_end_date_success(), # "test_mode": task_instance.test_mode, # "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events), - # "var": { - # "json": VariableAccessor(deserialize_json=True), - # "value": VariableAccessor(deserialize_json=False), - # }, + "var": { + "json": VariableAccessor(deserialize_json=True), + "value": VariableAccessor(deserialize_json=False), + }, "conn": ConnectionAccessor(), } if self._ti_context_from_server: diff --git a/task_sdk/tests/api/test_client.py b/task_sdk/tests/api/test_client.py index ff686537a96ed..8315a121fc4d7 100644 --- a/task_sdk/tests/api/test_client.py +++ b/task_sdk/tests/api/test_client.py @@ -362,16 +362,32 @@ def handle_request(request: httpx.Request) -> httpx.Response: client = make_client(transport=httpx.MockTransport(handle_request)) - with pytest.raises(ServerResponseError) as err: - client.variables.get(key="non_existent_var") - - assert err.value.response.status_code == 404 - assert err.value.detail == { - "detail": { - "message": "Variable with key 'non_existent_var' not found", - "reason": "not_found", - } - } + resp = client.variables.get(key="non_existent_var") + + assert isinstance(resp, ErrorResponse) + assert resp.error == ErrorType.VARIABLE_NOT_FOUND + assert resp.detail == {"key": "non_existent_var"} + + @mock.patch("time.sleep", return_value=None) + def test_variable_get_500_error(self, mock_sleep): + # Simulate a response from the server returning a 500 error + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/variables/test_key": + return httpx.Response( + status_code=500, + headers=[("content-Type", "application/json")], + json={ + "reason": "internal_server_error", + "message": "Internal Server Error", + }, + ) + return httpx.Response(status_code=400, json={"detail": "Bad Request"}) + + client = make_client(transport=httpx.MockTransport(handle_request)) + with pytest.raises(ServerResponseError): + client.variables.get( + key="test_key", + ) def test_variable_set_success(self): # Simulate a successful response from the server when putting a variable diff --git a/task_sdk/tests/execution_time/test_context.py b/task_sdk/tests/execution_time/test_context.py index 34502a1a91717..4aff0da0775eb 100644 --- a/task_sdk/tests/execution_time/test_context.py +++ b/task_sdk/tests/execution_time/test_context.py @@ -18,9 +18,15 @@ from __future__ import annotations from airflow.sdk.definitions.connection import Connection +from airflow.sdk.definitions.variable import Variable from airflow.sdk.exceptions import ErrorType -from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse -from airflow.sdk.execution_time.context import ConnectionAccessor, _convert_connection_result_conn +from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult +from airflow.sdk.execution_time.context import ( + ConnectionAccessor, + VariableAccessor, + _convert_connection_result_conn, + _convert_variable_result_to_variable, +) def test_convert_connection_result_conn(): @@ -48,6 +54,31 @@ def test_convert_connection_result_conn(): ) +def test_convert_variable_result_to_variable(): + """Test that the VariableResult is converted to a Variable object.""" + var = VariableResult( + key="test_key", + value="test_value", + ) + var = _convert_variable_result_to_variable(var, deserialize_json=False) + assert var == Variable( + key="test_key", + value="test_value", + ) + + +def test_convert_variable_result_to_variable_with_deserialize_json(): + """Test that the VariableResult is converted to a Variable object with deserialize_json set to True.""" + var = VariableResult( + key="test_key", + value='{\r\n "key1": "value1",\r\n "key2": "value2",\r\n "enabled": true,\r\n "threshold": 42\r\n}', + ) + var = _convert_variable_result_to_variable(var, deserialize_json=True) + assert var == Variable( + key="test_key", value={"key1": "value1", "key2": "value2", "enabled": True, "threshold": 42} + ) + + class TestConnectionAccessor: def test_getattr_connection(self, mock_supervisor_comms): """ @@ -90,3 +121,44 @@ def test_get_method_with_default(self, mock_supervisor_comms): conn = accessor.get("nonexistent_conn", default_conn=default_conn) assert conn == default_conn + + +class TestVariableAccessor: + def test_getattr_variable(self, mock_supervisor_comms): + """ + Test that the variable is fetched when accessed via __getattr__. + """ + accessor = VariableAccessor(deserialize_json=False) + + # Variable from the supervisor / API Server + var_result = VariableResult(key="test_key", value="test_value") + + mock_supervisor_comms.get_message.return_value = var_result + + # Fetch the variable; triggers __getattr__ + var = accessor.test_key + + expected_var = Variable(key="test_key", value="test_value") + assert var == expected_var + + def test_get_method_valid_variable(self, mock_supervisor_comms): + """Test that the get method returns the requested variable using `var.get`.""" + accessor = VariableAccessor(deserialize_json=False) + var_result = VariableResult(key="test_key", value="test_value") + + mock_supervisor_comms.get_message.return_value = var_result + + var = accessor.get("test_key") + assert var == Variable(key="test_key", value="test_value") + + def test_get_method_with_default(self, mock_supervisor_comms): + """Test that the get method returns the default variable when the requested variable is not found.""" + + accessor = VariableAccessor(deserialize_json=False) + default_var = {"default_key": "default_value"} + error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"}) + + mock_supervisor_comms.get_message.return_value = error_response + + var = accessor.get("nonexistent_var_key", default_var=default_var) + assert var == default_var diff --git a/task_sdk/tests/execution_time/test_task_runner.py b/task_sdk/tests/execution_time/test_task_runner.py index ebf03d3323315..8749bb2be1085 100644 --- a/task_sdk/tests/execution_time/test_task_runner.py +++ b/task_sdk/tests/execution_time/test_task_runner.py @@ -35,15 +35,18 @@ ) from airflow.sdk import DAG, BaseOperator, Connection from airflow.sdk.api.datamodels._generated import TaskInstance, TerminalTIState +from airflow.sdk.definitions.variable import Variable from airflow.sdk.execution_time.comms import ( ConnectionResult, DeferTask, GetConnection, + GetVariable, SetRenderedFields, StartupDetails, TaskState, + VariableResult, ) -from airflow.sdk.execution_time.context import ConnectionAccessor +from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor from airflow.sdk.execution_time.task_runner import ( CommsDecoder, RuntimeTaskInstance, @@ -590,6 +593,10 @@ def test_get_context_without_ti_context_from_server(self, mocked_parse, make_ti_ # Verify the context keys and values assert context == { + "var": { + "json": VariableAccessor(deserialize_json=True), + "value": VariableAccessor(deserialize_json=False), + }, "conn": ConnectionAccessor(), "dag": runtime_ti.task.dag, "inlets": task.inlets, @@ -623,6 +630,10 @@ def test_get_context_with_ti_context_from_server(self, mocked_parse, make_ti_con context = runtime_ti.get_template_context() assert context == { + "var": { + "json": VariableAccessor(deserialize_json=True), + "value": VariableAccessor(deserialize_json=False), + }, "conn": ConnectionAccessor(), "dag": runtime_ti.task.dag, "inlets": task.inlets, @@ -695,6 +706,51 @@ def test_get_connection_from_context(self, mocked_parse, make_ti_context, mock_s extra='{"extra_key": "extra_value"}', ) + @pytest.mark.parametrize( + ["accessor_type", "var_value", "expected_value"], + [ + pytest.param("value", "test_value", "test_value"), + pytest.param( + "json", + '{\r\n "key1": "value1",\r\n "key2": "value2",\r\n "enabled": true,\r\n "threshold": 42\r\n}', + {"key1": "value1", "key2": "value2", "enabled": True, "threshold": 42}, + ), + ], + ) + def test_get_variable_from_context( + self, mocked_parse, make_ti_context, mock_supervisor_comms, accessor_type, var_value, expected_value + ): + """Test that the variable is fetched from the API server via the Supervisor lazily when accessed""" + + task = BaseOperator(task_id="hello") + + ti_id = uuid7() + ti = TaskInstance( + id=ti_id, task_id=task.task_id, dag_id="basic_task", run_id="test_run", try_number=1 + ) + var = VariableResult(key="test_key", value=var_value) + + what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context()) + runtime_ti = mocked_parse(what, ti.dag_id, task) + mock_supervisor_comms.get_message.return_value = var + + context = runtime_ti.get_template_context() + + # Assert that the variable is not fetched from the API server yet! + # The variable should be only fetched connection is accessed + mock_supervisor_comms.send_request.assert_not_called() + mock_supervisor_comms.get_message.assert_not_called() + + # Access the variable from the context + var_from_context = context["var"][accessor_type].test_key + + mock_supervisor_comms.send_request.assert_called_once_with( + log=mock.ANY, msg=GetVariable(key="test_key") + ) + mock_supervisor_comms.get_message.assert_called_once_with() + + assert var_from_context == Variable(key="test_key", value=expected_value) + class TestXComAfterTaskExecution: @pytest.mark.parametrize(