Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"EdgeModifier": ".definitions.edges",
"Label": ".definitions.edges",
"Connection": ".definitions.connection",
"Variable": ".definitions.variable",
}


Expand Down
15 changes: 13 additions & 2 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,20 @@ class VariableOperations:
def __init__(self, client: Client):
self.client = client

def get(self, key: str) -> VariableResponse:
def get(self, key: str) -> VariableResponse | ErrorResponse:
"""Get a variable from the API server."""
resp = self.client.get(f"variables/{key}")
try:
resp = self.client.get(f"variables/{key}")
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.error(
"Variable not found",
key=key,
detail=e.detail,
status_code=e.response.status_code,
)
return ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"key": key})
raise
return VariableResponse.model_validate_json(resp.read())

def set(self, key: str, value: str | None, description: str | None = None):
Expand Down
41 changes: 41 additions & 0 deletions task_sdk/src/airflow/sdk/definitions/variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Any

import attrs


@attrs.define
class Variable:
"""
A generic way to store and retrieve arbitrary content or settings as a simple key/value store.

:param key: The variable key.
:param value: The variable value.
:param description: The variable description.

"""

key: str
# keeping as any for supporting deserialize_json
value: Any | None = None
description: str | None = None

# TODO: Extend this definition for reading/writing variables without context
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def from_variable_response(cls, variable_response: VariableResponse) -> Variable
VariableResponse is autogenerated from the API schema, so we need to convert it to VariableResult
for communication between the Supervisor and the task process.
"""
return cls(**variable_response.model_dump())
return cls(**variable_response.model_dump(exclude_defaults=True), type="VariableResult")


class ErrorResponse(BaseModel):
Expand Down
63 changes: 61 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,31 @@
import structlog

from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.types import NOTSET

if TYPE_CHECKING:
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.execution_time.comms import ConnectionResult
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult


def _convert_connection_result_conn(conn_result: ConnectionResult):
def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection:
from airflow.sdk.definitions.connection import Connection

# `by_alias=True` is used to convert the `schema` field to `schema_` in the Connection model
return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))


def _convert_variable_result_to_variable(var_result: VariableResult, deserialize_json: bool) -> Variable:
from airflow.sdk.definitions.variable import Variable

if deserialize_json:
import json

var_result.value = json.loads(var_result.value) # type: ignore
return Variable(**var_result.model_dump(exclude={"type"}))


def _get_connection(conn_id: str) -> Connection:
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
# or `airflow.sdk.execution_time.connection`
Expand All @@ -54,6 +66,26 @@ def _get_connection(conn_id: str) -> Connection:
return _convert_connection_result_conn(msg)


def _get_variable(key: str, deserialize_json: bool) -> Variable:
# TODO: This should probably be moved to a separate module like `airflow.sdk.execution_time.comms`
# or `airflow.sdk.execution_time.variable`
# A reason to not move it to `airflow.sdk.execution_time.comms` is that it
# will make that module depend on Task SDK, which is not ideal because we intend to
# keep Task SDK as a separate package than execution time mods.
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
raise AirflowRuntimeError(msg)

if TYPE_CHECKING:
assert isinstance(msg, VariableResult)
return _convert_variable_result_to_variable(msg, deserialize_json)


class ConnectionAccessor:
"""Wrapper to access Connection entries in template."""

Expand All @@ -76,3 +108,30 @@ def get(self, conn_id: str, default_conn: Any = None) -> Any:
if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
return default_conn
raise


class VariableAccessor:
"""Wrapper to access Variable values in template."""

def __init__(self, deserialize_json: bool) -> None:
self._deserialize_json = deserialize_json

def __eq__(self, other):
if not isinstance(other, VariableAccessor):
return False
# All instances of VariableAccessor are equal since it is a stateless dynamic accessor
return True

def __repr__(self) -> str:
return "<VariableAccessor (dynamic access)>"

def __getattr__(self, key: str) -> Any:
return _get_variable(key, self._deserialize_json)

def get(self, key, default_var: Any = NOTSET) -> Any:
try:
return _get_variable(key, self._deserialize_json)
except AirflowRuntimeError as e:
if e.error.error == ErrorType.VARIABLE_NOT_FOUND:
return default_var
raise
8 changes: 6 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
IntermediateTIState,
TaskInstance,
TerminalTIState,
VariableResponse,
)
from airflow.sdk.execution_time.comms import (
ConnectionResult,
Expand Down Expand Up @@ -722,8 +723,11 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger):
resp = conn.model_dump_json().encode()
elif isinstance(msg, GetVariable):
var = self.client.variables.get(msg.key)
var_result = VariableResult.from_variable_response(var)
resp = var_result.model_dump_json().encode()
if isinstance(var, VariableResponse):
var_result = VariableResult.from_variable_response(var)
resp = var_result.model_dump_json(exclude_unset=True).encode()
else:
resp = var.model_dump_json().encode()
elif isinstance(msg, GetXCom):
xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index)
xcom_result = XComResult.from_xcom_response(xcom)
Expand Down
10 changes: 5 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor
from airflow.sdk.execution_time.context import ConnectionAccessor, VariableAccessor

if TYPE_CHECKING:
from structlog.typing import FilteringBoundLogger as Logger
Expand Down Expand Up @@ -85,10 +85,10 @@ def get_template_context(self):
# "prev_end_date_success": get_prev_end_date_success(),
# "test_mode": task_instance.test_mode,
# "triggering_asset_events": lazy_object_proxy.Proxy(get_triggering_events),
# "var": {
# "json": VariableAccessor(deserialize_json=True),
# "value": VariableAccessor(deserialize_json=False),
# },
"var": {
"json": VariableAccessor(deserialize_json=True),
"value": VariableAccessor(deserialize_json=False),
},
"conn": ConnectionAccessor(),
}
if self._ti_context_from_server:
Expand Down
36 changes: 26 additions & 10 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,16 +362,32 @@ def handle_request(request: httpx.Request) -> httpx.Response:

client = make_client(transport=httpx.MockTransport(handle_request))

with pytest.raises(ServerResponseError) as err:
client.variables.get(key="non_existent_var")

assert err.value.response.status_code == 404
assert err.value.detail == {
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
}
resp = client.variables.get(key="non_existent_var")

assert isinstance(resp, ErrorResponse)
assert resp.error == ErrorType.VARIABLE_NOT_FOUND
assert resp.detail == {"key": "non_existent_var"}

@mock.patch("time.sleep", return_value=None)
def test_variable_get_500_error(self, mock_sleep):
# Simulate a response from the server returning a 500 error
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/variables/test_key":
return httpx.Response(
status_code=500,
headers=[("content-Type", "application/json")],
json={
"reason": "internal_server_error",
"message": "Internal Server Error",
},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))
with pytest.raises(ServerResponseError):
client.variables.get(
key="test_key",
)

def test_variable_set_success(self):
# Simulate a successful response from the server when putting a variable
Expand Down
76 changes: 74 additions & 2 deletions task_sdk/tests/execution_time/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@
from __future__ import annotations

from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse
from airflow.sdk.execution_time.context import ConnectionAccessor, _convert_connection_result_conn
from airflow.sdk.execution_time.comms import ConnectionResult, ErrorResponse, VariableResult
from airflow.sdk.execution_time.context import (
ConnectionAccessor,
VariableAccessor,
_convert_connection_result_conn,
_convert_variable_result_to_variable,
)


def test_convert_connection_result_conn():
Expand Down Expand Up @@ -48,6 +54,31 @@ def test_convert_connection_result_conn():
)


def test_convert_variable_result_to_variable():
"""Test that the VariableResult is converted to a Variable object."""
var = VariableResult(
key="test_key",
value="test_value",
)
var = _convert_variable_result_to_variable(var, deserialize_json=False)
assert var == Variable(
key="test_key",
value="test_value",
)


def test_convert_variable_result_to_variable_with_deserialize_json():
"""Test that the VariableResult is converted to a Variable object with deserialize_json set to True."""
var = VariableResult(
key="test_key",
value='{\r\n "key1": "value1",\r\n "key2": "value2",\r\n "enabled": true,\r\n "threshold": 42\r\n}',
)
var = _convert_variable_result_to_variable(var, deserialize_json=True)
assert var == Variable(
key="test_key", value={"key1": "value1", "key2": "value2", "enabled": True, "threshold": 42}
)


class TestConnectionAccessor:
def test_getattr_connection(self, mock_supervisor_comms):
"""
Expand Down Expand Up @@ -90,3 +121,44 @@ def test_get_method_with_default(self, mock_supervisor_comms):

conn = accessor.get("nonexistent_conn", default_conn=default_conn)
assert conn == default_conn


class TestVariableAccessor:
def test_getattr_variable(self, mock_supervisor_comms):
"""
Test that the variable is fetched when accessed via __getattr__.
"""
accessor = VariableAccessor(deserialize_json=False)

# Variable from the supervisor / API Server
var_result = VariableResult(key="test_key", value="test_value")

mock_supervisor_comms.get_message.return_value = var_result

# Fetch the variable; triggers __getattr__
var = accessor.test_key

expected_var = Variable(key="test_key", value="test_value")
assert var == expected_var

def test_get_method_valid_variable(self, mock_supervisor_comms):
"""Test that the get method returns the requested variable using `var.get`."""
accessor = VariableAccessor(deserialize_json=False)
var_result = VariableResult(key="test_key", value="test_value")

mock_supervisor_comms.get_message.return_value = var_result

var = accessor.get("test_key")
assert var == Variable(key="test_key", value="test_value")

def test_get_method_with_default(self, mock_supervisor_comms):
"""Test that the get method returns the default variable when the requested variable is not found."""

accessor = VariableAccessor(deserialize_json=False)
default_var = {"default_key": "default_value"}
error_response = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"})

mock_supervisor_comms.get_message.return_value = error_response

var = accessor.get("nonexistent_var_key", default_var=default_var)
assert var == default_var
Loading