Skip to content

Commit

Permalink
AIP-72: Add "Get Variable" endpoint for Execution API (apache#43832)
Browse files Browse the repository at this point in the history
This commit introduces a new endpoint, `/execution/variable/{variable_key}`, in the Execution API to retrieve Variables details.

Same as the Connections PR, it uses a placeholder `check_connection_access` function to validate task permissions for each request.
  • Loading branch information
kaxil authored Nov 8, 2024
1 parent 6c30fc5 commit 01302a1
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 22 deletions.
9 changes: 9 additions & 0 deletions airflow/api_fastapi/execution_api/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,15 @@ class ConnectionResponse(BaseModel):
extra: str | None


class VariableResponse(BaseModel):
"""Variable schema for responses with fields that are needed for Runtime."""

model_config = ConfigDict(from_attributes=True)

key: str
val: str | None = Field(alias="value")


# TODO: This is a placeholder for Task Identity Token schema.
class TIToken(BaseModel):
"""Task Identity Token."""
Expand Down
11 changes: 5 additions & 6 deletions airflow/api_fastapi/execution_api/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
from __future__ import annotations

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api.routes.connections import connection_router
from airflow.api_fastapi.execution_api.routes.health import health_router
from airflow.api_fastapi.execution_api.routes.task_instance import ti_router
from airflow.api_fastapi.execution_api.routes import connections, health, task_instance, variables

execution_api_router = AirflowRouter()
execution_api_router.include_router(connection_router)
execution_api_router.include_router(health_router)
execution_api_router.include_router(ti_router)
execution_api_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
execution_api_router.include_router(health.router, tags=["Health"])
execution_api_router.include_router(task_instance.router, prefix="/task_instance", tags=["Task Instance"])
execution_api_router.include_router(variables.router, prefix="/variables", tags=["Variables"])
6 changes: 2 additions & 4 deletions airflow/api_fastapi/execution_api/routes/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
from airflow.models.connection import Connection

# TODO: Add dependency on JWT token
connection_router = AirflowRouter(
prefix="/connection",
tags=["Connection"],
router = AirflowRouter(
responses={status.HTTP_404_NOT_FOUND: {"description": "Connection not found"}},
)

Expand All @@ -42,7 +40,7 @@ def get_task_token() -> datamodels.TIToken:
return datamodels.TIToken(ti_key="test_key")


@connection_router.get(
@router.get(
"/{connection_id}",
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/execution_api/routes/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

from airflow.api_fastapi.common.router import AirflowRouter

health_router = AirflowRouter(tags=["Health"])
router = AirflowRouter()


@health_router.get("/health")
@router.get("/health")
def health() -> dict:
return {"status": "healthy"}
9 changes: 3 additions & 6 deletions airflow/api_fastapi/execution_api/routes/task_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,13 @@
from airflow.utils.state import State

# TODO: Add dependency on JWT token
ti_router = AirflowRouter(
prefix="/task_instance",
tags=["Task Instance"],
)
router = AirflowRouter()


log = logging.getLogger(__name__)


@ti_router.patch(
@router.patch(
"/{task_instance_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
# TODO: Add description to the operation
Expand Down Expand Up @@ -133,7 +130,7 @@ def ti_update_state(
)


@ti_router.put(
@router.put(
"/{task_instance_id}/heartbeat",
status_code=status.HTTP_204_NO_CONTENT,
responses={
Expand Down
87 changes: 87 additions & 0 deletions airflow/api_fastapi/execution_api/routes/variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging

from fastapi import Depends, HTTPException, status
from typing_extensions import Annotated

from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.execution_api import datamodels
from airflow.models.variable import Variable

# TODO: Add dependency on JWT token
router = AirflowRouter(
responses={status.HTTP_404_NOT_FOUND: {"description": "Variable not found"}},
)

log = logging.getLogger(__name__)


def get_task_token() -> datamodels.TIToken:
"""TODO: Placeholder for task identity authentication. This should be replaced with actual JWT decoding and validation."""
return datamodels.TIToken(ti_key="test_key")


@router.get(
"/{variable_key}",
responses={
status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the variable"},
},
)
def get_variable(
variable_key: str,
token: Annotated[datamodels.TIToken, Depends(get_task_token)],
) -> datamodels.VariableResponse:
"""Get an Airflow Variable."""
if not has_variable_access(variable_key, token):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"reason": "access_denied",
"message": f"Task does not have access to variable {variable_key}",
},
)

try:
variable_value = Variable.get(variable_key)
except KeyError:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"Variable with key '{variable_key}' not found",
},
)

return datamodels.VariableResponse(key=variable_key, value=variable_value)


def has_variable_access(variable_key: str, token: datamodels.TIToken) -> bool:
"""Check if the task has access to the variable."""
# TODO: Placeholder for actual implementation

ti_key = token.ti_key
log.debug(
"Checking access for task instance with key '%s' to variable '%s'",
ti_key,
variable_key,
)
return True
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_connection_get_from_db(self, client, session):
session.add(connection)
session.commit()

response = client.get("/execution/connection/test_conn")
response = client.get("/execution/connections/test_conn")

assert response.status_code == 200
assert response.json() == {
Expand All @@ -66,7 +66,7 @@ def test_connection_get_from_db(self, client, session):
{"AIRFLOW_CONN_TEST_CONN2": '{"uri": "http://root:admin@localhost:8080/https?headers=header"}'},
)
def test_connection_get_from_env_var(self, client, session):
response = client.get("/execution/connection/test_conn2")
response = client.get("/execution/connections/test_conn2")

assert response.status_code == 200
assert response.json() == {
Expand All @@ -81,7 +81,7 @@ def test_connection_get_from_env_var(self, client, session):
}

def test_connection_get_not_found(self, client):
response = client.get("/execution/connection/non_existent_test_conn")
response = client.get("/execution/connections/non_existent_test_conn")

assert response.status_code == 404
assert response.json() == {
Expand All @@ -95,7 +95,7 @@ def test_connection_get_access_denied(self, client):
with mock.patch(
"airflow.api_fastapi.execution_api.routes.connections.has_connection_access", return_value=False
):
response = client.get("/execution/connection/test_conn")
response = client.get("/execution/connections/test_conn")

# Assert response status code and detail for access denied
assert response.status_code == 403
Expand Down
77 changes: 77 additions & 0 deletions tests/api_fastapi/execution_api/routes/test_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from unittest import mock

import pytest

from airflow.models.variable import Variable

pytestmark = pytest.mark.db_test


class TestGetVariable:
def test_variable_get_from_db(self, client, session):
Variable.set(key="var1", value="value", session=session)
session.commit()

response = client.get("/execution/variables/var1")

assert response.status_code == 200
assert response.json() == {"key": "var1", "value": "value"}

# Remove connection
Variable.delete(key="var1", session=session)
session.commit()

@mock.patch.dict(
"os.environ",
{"AIRFLOW_VAR_KEY1": "VALUE"},
)
def test_variable_get_from_env_var(self, client, session):
response = client.get("/execution/variables/key1")

assert response.status_code == 200
assert response.json() == {"key": "key1", "value": "VALUE"}

def test_variable_get_not_found(self, client):
response = client.get("/execution/variables/non_existent_var")

assert response.status_code == 404
assert response.json() == {
"detail": {
"message": "Variable with key 'non_existent_var' not found",
"reason": "not_found",
}
}

def test_variable_get_access_denied(self, client):
with mock.patch(
"airflow.api_fastapi.execution_api.routes.variables.has_variable_access", return_value=False
):
response = client.get("/execution/variables/key1")

# Assert response status code and detail for access denied
assert response.status_code == 403
assert response.json() == {
"detail": {
"reason": "access_denied",
"message": "Task does not have access to variable key1",
}
}

0 comments on commit 01302a1

Please sign in to comment.