Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Adding support to Set an Airflow Variable from Task SDK #44562

Merged
merged 11 commits into from
Dec 5, 2024
4 changes: 3 additions & 1 deletion airflow/api_fastapi/execution_api/datamodels/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from pydantic import Field

from airflow.api_fastapi.core_api.base import BaseModel
from airflow.api_fastapi.core_api.base import BaseModel, ConfigDict


class VariableResponse(BaseModel):
Expand All @@ -32,5 +32,7 @@ class VariableResponse(BaseModel):
class VariablePostBody(BaseModel):
"""Request body schema for creating variables."""

model_config = ConfigDict(extra="forbid")

value: str | None = Field(serialization_alias="val")
description: str | None = Field(default=None)
10 changes: 10 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TIHeartbeatInfo,
TITerminalStatePayload,
ValidationError as RemoteValidationError,
VariablePostBody,
VariableResponse,
XComResponse,
)
Expand Down Expand Up @@ -157,6 +158,15 @@ def get(self, key: str) -> VariableResponse:
resp = self.client.get(f"variables/{key}")
return VariableResponse.model_validate_json(resp.read())

def set(self, key: str, value: str | None, description: str | None = None):
"""Set an Airflow Variable via the API server."""
body = VariablePostBody(value=value, description=description)
self.client.put(f"variables/{key}", content=body.model_dump_json())
# Any error from the server will anyway be propagated down to the supervisor,
# so we choose to send a generic response to the supervisor over the server response to
# decouple from the server response string
return {"ok": True}


class XComOperations:
__slots__ = ("client",)
Expand Down
14 changes: 13 additions & 1 deletion task_sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from typing import Annotated, Any, Literal
from uuid import UUID

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field


class ConnectionResponse(BaseModel):
Expand Down Expand Up @@ -116,6 +116,18 @@ class ValidationError(BaseModel):
type: Annotated[str, Field(title="Error Type")]


class VariablePostBody(BaseModel):
"""
Request body schema for creating variables.
"""

model_config = ConfigDict(
extra="forbid",
)
value: Annotated[str | None, Field(title="Value")] = None
description: Annotated[str | None, Field(title="Description")] = None


class VariableResponse(BaseModel):
"""
Variable schema for responses with fields that are needed for Runtime.
Expand Down
9 changes: 8 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,14 @@ class GetVariable(BaseModel):
type: Literal["GetVariable"] = "GetVariable"


class PutVariable(BaseModel):
key: str
value: str | None
description: str | None
type: Literal["PutVariable"] = "PutVariable"


ToSupervisor = Annotated[
Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask],
Union[TaskState, GetXCom, GetConnection, GetVariable, DeferTask, PutVariable],
Field(discriminator="type"),
]
11 changes: 10 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,17 @@
from pydantic import TypeAdapter

from airflow.sdk.api.client import Client, ServerResponseError
from airflow.sdk.api.datamodels._generated import IntermediateTIState, TaskInstance, TerminalTIState
from airflow.sdk.api.datamodels._generated import (
IntermediateTIState,
TaskInstance,
TerminalTIState,
)
from airflow.sdk.execution_time.comms import (
DeferTask,
GetConnection,
GetVariable,
GetXCom,
PutVariable,
StartupDetails,
TaskState,
ToSupervisor,
Expand Down Expand Up @@ -669,6 +674,10 @@ def handle_requests(self, log: FilteringBoundLogger) -> Generator[None, bytes, N
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.ti_id, msg)
resp = None
elif isinstance(msg, PutVariable):
self.client.variables.set(msg.key, msg.value, msg.description)
resp = None
else:
log.error("Unhandled request", msg=msg)
continue
Expand Down
15 changes: 15 additions & 0 deletions task_sdk/tests/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,18 @@ def handle_request(request: httpx.Request) -> httpx.Response:
"reason": "not_found",
}
}

def test_variable_set_success(self):
# Simulate a successful response from the server when putting a variable
def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/variables/test_key":
return httpx.Response(
status_code=201,
json={"message": "Variable successfully set"},
)
return httpx.Response(status_code=400, json={"detail": "Bad Request"})

client = make_client(transport=httpx.MockTransport(handle_request))

result = client.variables.set(key="test_key", value="test_value", description="test_description")
assert result == {"ok": True}
9 changes: 9 additions & 0 deletions task_sdk/tests/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
GetConnection,
GetVariable,
GetXCom,
PutVariable,
VariableResult,
XComResult,
)
Expand Down Expand Up @@ -803,6 +804,14 @@ def watched_subprocess(self, mocker):
"",
id="patch_task_instance_to_deferred",
),
pytest.param(
PutVariable(key="test_key", value="test_value", description="test_description"),
b"",
"variables.set",
("test_key", "test_value", "test_description"),
{"ok": True},
id="set_variable",
),
],
)
def test_handle_requests(
Expand Down
25 changes: 24 additions & 1 deletion tests/api_fastapi/execution_api/routes/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_should_create_variable(self, client, payload, session):
pytest.param("var_create", 422, {"description": "description"}, id="missing-value"),
],
)
def test_variable_missing_fields(self, client, key, status_code, payload, session):
def test_variable_missing_mandatory_fields(self, client, key, status_code, payload, session):
response = client.put(
f"/execution/variables/{key}",
json=payload,
Expand All @@ -127,6 +127,29 @@ def test_variable_missing_fields(self, client, key, status_code, payload, sessio
assert response.json()["detail"][0]["type"] == "missing"
assert response.json()["detail"][0]["msg"] == "Field required"

@pytest.mark.parametrize(
"key, payload",
[
pytest.param("key", {"key": "key", "value": "{}", "description": "description"}, id="adding-key"),
pytest.param(
"key", {"type": "PutVariable", "value": "{}", "description": "description"}, id="adding-type"
),
pytest.param(
"key",
{"value": "{}", "description": "description", "lorem": "ipsum", "foo": "bar"},
id="adding-extra-fields",
),
],
)
def test_variable_adding_extra_fields(self, client, key, payload, session):
response = client.put(
f"/execution/variables/{key}",
json=payload,
)
assert response.status_code == 422
assert response.json()["detail"][0]["type"] == "extra_forbidden"
assert response.json()["detail"][0]["msg"] == "Extra inputs are not permitted"

def test_overwriting_existing_variable(self, client, session):
key = "var_create"
Variable.set(key=key, value="value", session=session)
Expand Down