Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Annotated

from fastapi import Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy import delete, select

from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
from airflow.api_fastapi.common.parameters import (
Expand Down Expand Up @@ -62,7 +62,11 @@ def delete_variable(
session: SessionDep,
):
"""Delete a variable entry."""
if Variable.delete(variable_key, session) == 0:
# Like the other endpoints (get, patch), we do not use Variable.delete/get/set here because these methods
# are intended to be used in task execution environment (execution API)
result = session.execute(delete(Variable).where(Variable.key == variable_key))
rows = getattr(result, "rowcount", 0) or 0
if rows == 0:
raise HTTPException(
status.HTTP_404_NOT_FOUND, f"The Variable with key: `{variable_key}` was not found"
)
Expand Down
22 changes: 22 additions & 0 deletions airflow-core/src/airflow/api_fastapi/execution_api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@
import svcs
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import HTTPBearer
from sqlalchemy import select

from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.common.db.common import AsyncSessionDep
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.configuration import conf
from airflow.models import DagModel, TaskInstance
from airflow.models.dagbundle import DagBundleModel
from airflow.models.team import Team

log = structlog.get_logger(logger_name=__name__)

Expand Down Expand Up @@ -95,3 +101,19 @@ async def __call__( # type: ignore[override]

# This checks that the UUID in the url matches the one in the token for us.
JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))


async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None:
"""Return the team name associated to the task (if any)."""
if not conf.getboolean("core", "multi_team"):
return None

stmt = (
select(Team.name)
.select_from(TaskInstance)
.join(DagModel, DagModel.dag_id == TaskInstance.dag_id)
.join(DagBundleModel, DagBundleModel.name == DagModel.bundle_name)
.join(DagBundleModel.teams)
.where(TaskInstance.id == str(token.id))
)
return await session.scalar(stmt)
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@
from __future__ import annotations

import logging
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Path, Request, status

from airflow.api_fastapi.execution_api.datamodels.variable import (
VariablePostBody,
VariableResponse,
)
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.deps import JWTBearerDep, get_team_name_dep
from airflow.models.variable import Variable


Expand Down Expand Up @@ -61,13 +62,15 @@ async def has_variable_access(
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def get_variable(variable_key: str) -> VariableResponse:
def get_variable(
variable_key: str, team_name: Annotated[str | None, Depends(get_team_name_dep)]
) -> VariableResponse:
"""Get an Airflow Variable."""
if not variable_key:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Not Found")

try:
variable_value = Variable.get(variable_key)
variable_value = Variable.get(variable_key, team_name=team_name)
except KeyError:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
Expand All @@ -88,12 +91,14 @@ def get_variable(variable_key: str) -> VariableResponse:
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def put_variable(variable_key: str, body: VariablePostBody):
def put_variable(
variable_key: str, body: VariablePostBody, team_name: Annotated[str | None, Depends(get_team_name_dep)]
):
"""Set an Airflow Variable."""
if not variable_key:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Not Found")

Variable.set(key=variable_key, value=body.value, description=body.description)
Variable.set(key=variable_key, value=body.value, description=body.description, team_name=team_name)
return {"message": "Variable successfully set"}


Expand All @@ -105,9 +110,9 @@ def put_variable(variable_key: str, body: VariablePostBody):
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def delete_variable(variable_key: str):
def delete_variable(variable_key: str, team_name: Annotated[str | None, Depends(get_team_name_dep)]):
"""Delete an Airflow Variable."""
if not variable_key:
raise HTTPException(status.HTTP_404_NOT_FOUND, detail="Not Found")

Variable.delete(key=variable_key)
Variable.delete(key=variable_key, team_name=team_name)
4 changes: 4 additions & 0 deletions airflow-core/src/airflow/cli/commands/team_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from __future__ import annotations

import re

from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError

Expand Down Expand Up @@ -50,6 +52,8 @@ def _extract_team_name(args):
team_name = args.name.strip()
if not team_name:
raise SystemExit("Team name cannot be empty")
if not re.match(r"^[a-zA-Z0-9_-]{3,50}$", team_name):
raise SystemExit("Invalid team name: must match regex ^[a-zA-Z0-9_-]{3,50}$")
return team_name


Expand Down
59 changes: 45 additions & 14 deletions airflow-core/src/airflow/models/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import warnings
from typing import TYPE_CHECKING, Any

from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, select
from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, or_, select
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import Mapped, declared_attr, reconstructor, synonym

Expand Down Expand Up @@ -139,13 +139,15 @@ def get(
key: str,
default_var: Any = __NO_DEFAULT_SENTINEL,
deserialize_json: bool = False,
team_name: str | None = None,
) -> Any:
"""
Get a value for an Airflow Variable Key.

:param key: Variable Key
:param default_var: Default value of the Variable if the Variable doesn't exist
:param deserialize_json: Deserialize the value to a Python dict
:param team_name: Team name associated to the task trying to access the variable (if any)
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
# means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big
Expand All @@ -169,7 +171,12 @@ def get(

return var_val

var_val = Variable.get_variable_from_secrets(key=key)
if team_name and not conf.getboolean("core", "multi_team"):
raise ValueError(
"Multi-team mode is not configured in the Airflow environment but the task trying to access the variable belongs to a team"
)

var_val = Variable.get_variable_from_secrets(key=key, team_name=team_name)
if var_val is None:
if default_var is not cls.__NO_DEFAULT_SENTINEL:
return default_var
Expand Down Expand Up @@ -319,6 +326,7 @@ def update(
key: str,
value: Any,
serialize_json: bool = False,
team_name: str | None = None,
session: Session | None = None,
) -> None:
"""
Expand All @@ -327,6 +335,7 @@ def update(
:param key: Variable Key
:param value: Value to set for the Variable
:param serialize_json: Serialize the value to a JSON string
:param team_name: Team name associated to the variable (if any)
:param session: optional session, use if provided or create a new one
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
Expand All @@ -352,9 +361,14 @@ def update(
)
return

if team_name and not conf.getboolean("core", "multi_team"):
raise ValueError(
"Multi-team mode is not configured in the Airflow environment. To assign a team to a variable, multi-mode must be enabled."
)

Variable.check_for_write_conflict(key=key)

if Variable.get_variable_from_secrets(key=key) is None:
if Variable.get_variable_from_secrets(key=key, team_name=team_name) is None:
raise KeyError(f"Variable {key} does not exist")

ctx: contextlib.AbstractContextManager
Expand All @@ -364,7 +378,11 @@ def update(
ctx = create_session()

with ctx as session:
obj = session.scalar(select(Variable).where(Variable.key == key))
obj = session.scalar(
select(Variable).where(
Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None))
)
)
if obj is None:
raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.")

Expand All @@ -377,11 +395,12 @@ def update(
)

@staticmethod
def delete(key: str, session: Session | None = None) -> int:
def delete(key: str, team_name: str | None = None, session: Session | None = None) -> int:
"""
Delete an Airflow Variable for a given key.

:param key: Variable Keys
:param team_name: Team name associated to the task trying to delete the variable (if any)
:param session: optional session, use if provided or create a new one
"""
# TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still
Expand All @@ -404,14 +423,23 @@ def delete(key: str, session: Session | None = None) -> int:
)
return 1

if team_name and not conf.getboolean("core", "multi_team"):
raise ValueError(
"Multi-team mode is not configured in the Airflow environment but the task trying to delete the variable belongs to a team"
)

ctx: contextlib.AbstractContextManager
if session is not None:
ctx = contextlib.nullcontext(session)
else:
ctx = create_session()

with ctx as session:
result = session.execute(delete(Variable).where(Variable.key == key))
result = session.execute(
delete(Variable).where(
Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None))
)
)
rows = getattr(result, "rowcount", 0) or 0
SecretCache.invalidate_variable(key)
return rows
Expand Down Expand Up @@ -458,25 +486,28 @@ def check_for_write_conflict(key: str) -> None:
return None

@staticmethod
def get_variable_from_secrets(key: str) -> str | None:
def get_variable_from_secrets(key: str, team_name: str | None = None) -> str | None:
"""
Get Airflow Variable by iterating over all Secret Backends.

:param key: Variable Key
:param team_name: Team name associated to the task trying to access the variable (if any)
:return: Variable Value
"""
# check cache first
# enabled only if SecretCache.init() has been called first
try:
return SecretCache.get_variable(key)
except SecretCache.NotPresentException:
pass # continue business
# Disable cache if the variable belongs to a team. We might enable it later
if not team_name:
# check cache first
# enabled only if SecretCache.init() has been called first
try:
return SecretCache.get_variable(key)
except SecretCache.NotPresentException:
pass # continue business

var_val = None
# iterate over backends if not in cache (or expired)
for secrets_backend in ensure_secrets_loaded():
try:
var_val = secrets_backend.get_variable(key=key)
var_val = secrets_backend.get_variable(key=key, team_name=team_name)
if var_val is not None:
break
except Exception:
Expand Down
9 changes: 8 additions & 1 deletion airflow-core/src/airflow/secrets/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@ class EnvironmentVariablesBackend(BaseSecretsBackend):
def get_conn_value(self, conn_id: str) -> str | None:
return os.environ.get(CONN_ENV_PREFIX + conn_id.upper())

def get_variable(self, key: str) -> str | None:
def get_variable(self, key: str, team_name: str | None = None) -> str | None:
"""
Get Airflow Variable from Environment Variable.

:param key: Variable Key
:param team_name: Team name associated to the task trying to access the variable (if any)
:return: Variable Value
"""
if team_name and (
team_var := os.environ.get(f"{VAR_ENV_PREFIX}_{team_name.upper()}___" + key.upper())
):
# Format to set a team specific variable: AIRFLOW_VAR__<TEAM_ID>___<VAR_KEY>
return team_var

return os.environ.get(VAR_ENV_PREFIX + key.upper())
2 changes: 1 addition & 1 deletion airflow-core/src/airflow/secrets/local_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def get_connection(self, conn_id: str) -> Connection | None:
return self._local_connections[conn_id]
return None

def get_variable(self, key: str) -> str | None:
def get_variable(self, key: str, team_name: str | None = None) -> str | None:
return self._local_variables.get(key)

def get_config(self, key: str) -> str | None:
Expand Down
13 changes: 10 additions & 3 deletions airflow-core/src/airflow/secrets/metastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from typing import TYPE_CHECKING

from sqlalchemy import select
from sqlalchemy import or_, select

from airflow.secrets import BaseSecretsBackend
from airflow.utils.session import NEW_SESSION, provide_session
Expand Down Expand Up @@ -51,17 +51,24 @@ def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connec
return conn

@provide_session
def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None:
def get_variable(
self, key: str, team_name: str | None = None, session: Session = NEW_SESSION
) -> str | None:
"""
Get Airflow Variable from Metadata DB.

:param key: Variable Key
:param team_name: Team name associated to the task trying to access the variable (if any)
:param session: SQLAlchemy Session
:return: Variable Value
"""
from airflow.models import Variable

var_value = session.scalar(select(Variable).where(Variable.key == key).limit(1))
var_value = session.scalar(
select(Variable)
.where(Variable.key == key, or_(Variable.team_name == team_name, Variable.team_name.is_(None)))
.limit(1)
)
session.expunge_all()
if var_value:
return var_value.val
Expand Down
Loading
Loading