Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,25 @@ def _get_variable(key: str, deserialize_json: bool) -> Any:
)

# If no backend found the variable, raise a not found error (mirrors _get_connection)
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

if SUPERVISOR_COMMS is None:
raise AirflowRuntimeError(
ErrorResponse(
error=ErrorType.VARIABLE_NOT_FOUND,
detail={
"message": (
f"Variable '{key}' not found. Note: SUPERVISOR_COMMS is not available, "
"which means this code is running outside a task execution context "
"(e.g., at the top level of a DAG file). "
"Consider using environment variables (AIRFLOW_VAR_<key>), "
"Jinja templates ({{ var.value.<key> }}), "
"or move the Variable.get() call inside a task function."
)
},
)
)

raise AirflowRuntimeError(
ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"message": f"Variable {key} not found"})
Expand All @@ -283,7 +300,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ
import json

from airflow.sdk.execution_time.cache import SecretCache
from airflow.sdk.execution_time.comms import PutVariable
from airflow.sdk.execution_time.comms import ErrorResponse, PutVariable
from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend
from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
Expand Down Expand Up @@ -317,6 +334,20 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ
except Exception as e:
log.exception(e)

if SUPERVISOR_COMMS is None:
raise AirflowRuntimeError(
ErrorResponse(
error=ErrorType.GENERIC_ERROR,
detail={
"message": (
"Variable.set() requires a task execution context (SUPERVISOR_COMMS is not available). "
"This typically happens when calling Variable.set() at the top level of a DAG file "
"or outside of a running task. Variable.set() can only be used inside a task."
)
},
)
)

SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description))

# Invalidate cache after setting the variable
Expand All @@ -330,9 +361,23 @@ def _delete_variable(key: str) -> None:
# will make that module depend on Task SDK, which is not ideal because we intend to
# keep Task SDK as a separate package than execution time mods.
from airflow.sdk.execution_time.cache import SecretCache
from airflow.sdk.execution_time.comms import DeleteVariable
from airflow.sdk.execution_time.comms import DeleteVariable, ErrorResponse
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

if SUPERVISOR_COMMS is None:
raise AirflowRuntimeError(
ErrorResponse(
error=ErrorType.GENERIC_ERROR,
detail={
"message": (
"Variable.delete() requires a task execution context (SUPERVISOR_COMMS is not available). "
"This typically happens when calling Variable.delete() at the top level of a DAG file "
"or outside of a running task. Variable.delete() can only be used inside a task."
)
},
)
)

msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key))
if TYPE_CHECKING:
assert isinstance(msg, OKResponse)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti
from airflow.sdk.execution_time.context import _process_connection_result_conn
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

if SUPERVISOR_COMMS is None:
return None

try:
msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id))

Expand Down Expand Up @@ -102,6 +105,9 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None:
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

if SUPERVISOR_COMMS is None:
return None

try:
msg = SUPERVISOR_COMMS.send(GetVariable(key=key))

Expand Down Expand Up @@ -129,6 +135,9 @@ async def aget_connection(self, conn_id: str) -> Connection | None: # type: ign
from airflow.sdk.execution_time.context import _process_connection_result_conn
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

if SUPERVISOR_COMMS is None:
return None

try:
msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id))

Expand All @@ -153,6 +162,9 @@ async def aget_variable(self, key: str) -> str | None:
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

if SUPERVISOR_COMMS is None:
return None

try:
msg = await SUPERVISOR_COMMS.asend(GetVariable(key=key))

Expand Down
2 changes: 1 addition & 1 deletion task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance:
# deeply nested execution stack.
# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily
# accessible wherever needed during task execution without modifying every layer of the call stack.
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor]
SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] | None = None


# State machine!
Expand Down
46 changes: 46 additions & 0 deletions task-sdk/tests/task_sdk/definitions/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from airflow.sdk import Variable
from airflow.sdk.configuration import initialize_secrets_backends
from airflow.sdk.exceptions import AirflowRuntimeError
from airflow.sdk.execution_time.comms import PutVariable, VariableResult
from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS

Expand Down Expand Up @@ -186,3 +187,48 @@ def test_backend_fallback_to_env_var(self, mock_get_variable, mock_env_get, mock
# mock_env is only called when LocalFilesystemBackend doesn't have it
mock_env_get.assert_called()
assert var == "fake_value"


class TestVariableOutsideTaskContext:
"""Tests for Variable operations when SUPERVISOR_COMMS is None (outside task execution context)."""

@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_variable")
def test_get_with_env_var_works_without_supervisor_comms(self, mock_env_get, monkeypatch):
"""Variable.get() should still work via EnvironmentVariablesBackend when SUPERVISOR_COMMS is None."""
from airflow.sdk.execution_time import task_runner

monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", None)
mock_env_get.return_value = "env_value"

result = Variable.get(key="my_env_var")
assert result == "env_value"
mock_env_get.assert_called_once_with(key="my_env_var")

def test_get_not_found_without_supervisor_comms(self, monkeypatch):
"""Variable.get() should raise with a helpful message when variable not found and SUPERVISOR_COMMS is None."""
from airflow.sdk.execution_time import task_runner

monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", None)

with pytest.raises(AirflowRuntimeError, match="outside a task execution context"):
Variable.get(key="nonexistent_var")

def test_set_without_supervisor_comms(self, monkeypatch):
"""Variable.set() should raise AirflowRuntimeError when SUPERVISOR_COMMS is None."""
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.context import _set_variable

monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", None)

with pytest.raises(AirflowRuntimeError, match="Variable.set\\(\\) requires a task execution context"):
_set_variable(key="my_key", value="my_value")

def test_delete_without_supervisor_comms(self, monkeypatch):
"""Variable.delete() should raise AirflowRuntimeError when SUPERVISOR_COMMS is None."""
from airflow.sdk.execution_time import task_runner
from airflow.sdk.execution_time.context import _delete_variable

monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", None)

with pytest.raises(AirflowRuntimeError, match="Variable.delete\\(\\) requires a task execution context"):
_delete_variable(key="my_key")
Loading