Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
)
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
from airflow.api_fastapi.core_api.security import GetUserDep, ReadableTIFilterDep, requires_access_dag
from airflow.api_fastapi.logging.decorators import action_logging
from airflow.models.hitl import HITLDetail as HITLDetailModel
from airflow.models.taskinstance import TaskInstance as TI

Expand Down Expand Up @@ -181,7 +182,10 @@ def _get_hitl_detail(
status.HTTP_409_CONFLICT,
]
),
dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.HITL_DETAIL))],
dependencies=[
Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.HITL_DETAIL)),
Depends(action_logging()),
],
)
def update_hitl_detail(
dag_id: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
# under the License.
from __future__ import annotations

import json
from datetime import datetime
from typing import TYPE_CHECKING, Any
from unittest import mock

import pytest
import time_machine
from sqlalchemy import select
from sqlalchemy.orm import Session

from airflow._shared.timezones.timezone import utcnow
from airflow.models.hitl import HITLDetail
from airflow.models.log import Log
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
Expand Down Expand Up @@ -256,19 +259,57 @@ def expected_sample_hitl_detail_dict(sample_ti: TaskInstance) -> dict[str, Any]:
}


@pytest.fixture(autouse=True)
def cleanup_audit_log(session: Session) -> None:
session.query(Log).delete()
session.commit()


def _assert_sample_audit_log(audit_log: Log, map_index: int | None) -> None:
assert audit_log.dag_id == DAG_ID
assert audit_log.task_id == TASK_ID
assert audit_log.run_id == "test"
assert audit_log.map_index is None
assert audit_log.try_number is None
assert audit_log.owner == "test"
assert audit_log.owner_display_name == "test"
assert audit_log.event == "update_hitl_detail"

if TYPE_CHECKING:
assert isinstance(audit_log.extra, str)

expected_extra = {
"chosen_options": ["Approve"],
"params_input": {"input_1": 2},
"method": "PATCH",
}
if map_index is not None:
expected_extra["map_index"] = str(map_index)

assert json.loads(audit_log.extra) == expected_extra


@pytest.fixture
def sample_update_payload() -> dict[str, Any]:
return {"chosen_options": ["Approve"], "params_input": {"input_1": 2}}


class TestUpdateHITLDetailEndpoint:
@pytest.mark.parametrize("query_param", ["", "?map_index=-1"])
@time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False)
@pytest.mark.usefixtures("sample_hitl_detail")
@pytest.mark.parametrize("map_index", [None, -1])
def test_should_respond_200_with_existing_response(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
query_param: str,
map_index: int | None,
sample_update_payload: dict[str, Any],
session: Session,
) -> None:
query_param = "" if map_index is None else f"?map_index={map_index}"
response = test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}{query_param}",
json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}},
json=sample_update_payload,
)

assert response.status_code == 200
Expand All @@ -279,19 +320,25 @@ def test_should_respond_200_with_existing_response(
"response_at": "2025-07-03T00:00:00Z",
}

@pytest.mark.parametrize("query_param", ["", "?map_index=-1"])
audit_log = session.scalar(select(Log))
_assert_sample_audit_log(audit_log, map_index=map_index)

@time_machine.travel(datetime(2025, 7, 3, 0, 0, 0), tick=False)
@pytest.mark.usefixtures("sample_hitl_detail_respondent")
@pytest.mark.parametrize("map_index", [None, -1])
def test_should_respond_200_to_respondent_user(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
query_param: str,
map_index: int | None,
sample_update_payload: dict[str, Any],
session: Session,
):
"""Test with an authorized user and the user is a respondent to the task."""
query_param = "" if map_index is None else f"?map_index={map_index}"
response = test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}{query_param}",
json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}},
json=sample_update_payload,
)

assert response.status_code == 200
Expand All @@ -302,16 +349,20 @@ def test_should_respond_200_to_respondent_user(
"response_at": "2025-07-03T00:00:00Z",
}

audit_log = session.scalar(select(Log))
_assert_sample_audit_log(audit_log, map_index=map_index)

@pytest.mark.parametrize("query_param", ["", "?map_index=-1"])
def test_should_respond_401(
self,
unauthenticated_test_client: TestClient,
sample_ti_url_identifier: str,
sample_update_payload: dict[str, Any],
query_param: str,
) -> None:
response = unauthenticated_test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}{query_param}",
json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}},
json=sample_update_payload,
)
assert response.status_code == 401

Expand All @@ -320,11 +371,12 @@ def test_should_respond_403(
self,
unauthorized_test_client: TestClient,
sample_ti_url_identifier: str,
sample_update_payload: dict[str, Any],
query_param: str,
) -> None:
response = unauthorized_test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}{query_param}",
json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}},
json=sample_update_payload,
)
assert response.status_code == 403

Expand All @@ -335,12 +387,13 @@ def test_should_respond_403_to_non_respondent_user(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
sample_update_payload: dict[str, Any],
query_param: str,
):
"""Test with an authorized user but the user is not a respondent to the task."""
response = test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}{query_param}",
json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}},
json=sample_update_payload,
)
assert response.status_code == 403

Expand All @@ -363,12 +416,13 @@ def test_should_respond_404_without_hitl_detail(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
expected_hitl_detail_not_found_error_msg: str,
sample_update_payload: dict[str, Any],
query_param: str,
expected_hitl_detail_not_found_error_msg: str,
) -> None:
response = test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}{query_param}",
json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}},
json=sample_update_payload,
)

assert response.status_code == 404
Expand All @@ -381,8 +435,8 @@ def test_should_respond_409(
self,
test_client: TestClient,
sample_ti_url_identifier: str,
sample_ti: TaskInstance,
query_param: str,
sample_ti: TaskInstance,
) -> None:
response = test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}{query_param}",
Expand All @@ -400,7 +454,7 @@ def test_should_respond_409(

response = test_client.patch(
f"/hitlDetails/{sample_ti_url_identifier}{query_param}",
json={"chosen_options": ["Approve"], "params_input": {"input_1": 2}},
json={"chosen_options": ["Approve"], "params_input": {"input_1": 3}},
)
assert response.status_code == 409
assert response.json() == {
Expand Down
Loading