Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions task-sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,12 @@ def _convert_variable_result_to_variable(var_result: VariableResult, deserialize


def _get_connection(conn_id: str) -> Connection:
from airflow.sdk.execution_time.supervisor import SECRETS_BACKEND
from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded
# TODO: check cache first
# enabled only if SecretCache.init() has been called first

# iterate over configured backends if not in cache (or expired)
for secrets_backend in SECRETS_BACKEND:
for secrets_backend in ensure_secrets_backend_loaded():
try:
conn = secrets_backend.get_connection(conn_id=conn_id)
if conn:
Expand Down Expand Up @@ -155,11 +155,11 @@ def _get_connection(conn_id: str) -> Connection:
def _get_variable(key: str, deserialize_json: bool) -> Any:
# TODO: check cache first
# enabled only if SecretCache.init() has been called first
from airflow.sdk.execution_time.supervisor import SECRETS_BACKEND
from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

var_val = None
# iterate over backends if not in cache (or expired)
for secrets_backend in SECRETS_BACKEND:
for secrets_backend in ensure_secrets_backend_loaded():
try:
var_val = secrets_backend.get_variable(key=key) # type: ignore[assignment]
if var_val is not None:
Expand Down
12 changes: 4 additions & 8 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@
from airflow.typing_compat import Self


__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise", "SECRETS_BACKEND"]
__all__ = ["ActivitySubprocess", "WatchedSubprocess", "supervise"]

log: FilteringBoundLogger = structlog.get_logger(logger_name="supervisor")

Expand All @@ -124,8 +124,6 @@
TerminalTIState.SUCCESS,
]

SECRETS_BACKEND: list[BaseSecretsBackend] = []


@overload
def mkpipe() -> tuple[socket, socket]: ...
Expand Down Expand Up @@ -1070,14 +1068,12 @@ def forward_to_log(
log.log(level, msg, chan=chan)


def initialize_secrets_backend_on_workers():
def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]:
"""Initialize the secrets backend on workers."""
from airflow.configuration import ensure_secrets_loaded
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS

global SECRETS_BACKEND
SECRETS_BACKEND = ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS)
log.debug("Initialized secrets backend on workers", secrets_backend=SECRETS_BACKEND)
return ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS)


def register_secrets_masker():
Expand Down Expand Up @@ -1145,7 +1141,7 @@ def supervise(
processors = logging_processors(enable_pretty_log=pretty_logs)[0]
logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()

initialize_secrets_backend_on_workers()
ensure_secrets_backend_loaded()

register_secrets_masker()

Expand Down
4 changes: 0 additions & 4 deletions task-sdk/tests/task_sdk/definitions/test_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from airflow.exceptions import AirflowException
from airflow.sdk import Connection
from airflow.sdk.execution_time.comms import ConnectionResult
from airflow.sdk.execution_time.supervisor import initialize_secrets_backend_on_workers
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS

from tests_common.test_utils.config import conf_vars
Expand Down Expand Up @@ -112,15 +111,13 @@ def test_get_connection_secrets_backend(self, mock_supervisor_comms, tmp_path):
("workers", "secrets_backend_kwargs"): f'{{"connections_file_path": "{path}"}}',
}
):
initialize_secrets_backend_on_workers()
retrieved_conn = Connection.get(conn_id="CONN_A")
assert retrieved_conn is not None
assert retrieved_conn.conn_id == "CONN_A"

@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
def test_get_connection_env_var(self, mock_env_get, mock_supervisor_comms):
"""Tests getting a connection from environment variable."""
initialize_secrets_backend_on_workers()
mock_env_get.return_value = Connection(conn_id="something", conn_type="some-type") # return None
Connection.get("something")
mock_env_get.assert_called_once_with(conn_id="something")
Expand All @@ -135,7 +132,6 @@ def test_get_connection_env_var(self, mock_env_get, mock_supervisor_comms):
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
def test_backend_fallback_to_env_var(self, mock_get_connection, mock_env_get, mock_supervisor_comms):
"""Tests if connection retrieval falls back to environment variable backend if not found in secrets backend."""
initialize_secrets_backend_on_workers()
mock_get_connection.return_value = None
mock_env_get.return_value = Connection(conn_id="something", conn_type="some-type")

Expand Down
4 changes: 0 additions & 4 deletions task-sdk/tests/task_sdk/definitions/test_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from airflow.configuration import initialize_secrets_backends
from airflow.sdk import Variable
from airflow.sdk.execution_time.comms import VariableResult
from airflow.sdk.execution_time.supervisor import initialize_secrets_backend_on_workers
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS

from tests_common.test_utils.config import conf_vars
Expand Down Expand Up @@ -72,15 +71,13 @@ def test_var_get_from_secrets_found(self, mock_supervisor_comms, tmp_path):
("workers", "secrets_backend_kwargs"): f'{{"variables_file_path": "{path}"}}',
}
):
initialize_secrets_backend_on_workers()
retrieved_var = Variable.get(key="VAR_A")
assert retrieved_var is not None
assert retrieved_var == "some_value"

@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_variable")
def test_get_variable_env_var(self, mock_env_get, mock_supervisor_comms):
"""Tests getting a variable from environment variable."""
initialize_secrets_backend_on_workers()
mock_env_get.return_value = "fake_value"
Variable.get(key="fake_var_key")
mock_env_get.assert_called_once_with(key="fake_var_key")
Expand All @@ -97,7 +94,6 @@ def test_get_variable_env_var(self, mock_env_get, mock_supervisor_comms):
@mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_variable")
def test_backend_fallback_to_env_var(self, mock_get_variable, mock_env_get, mock_supervisor_comms):
"""Tests if variable retrieval falls back to environment variable backend if not found in secrets backend."""
initialize_secrets_backend_on_workers()
mock_get_variable.return_value = None
mock_env_get.return_value = "fake_value"

Expand Down