Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class DagAccessEntity(Enum):
AUDIT_LOG = "AUDIT_LOG"
CODE = "CODE"
DEPENDENCIES = "DEPENDENCIES"
HITL_DETAIL = "HITL_DETAIL"
RUN = "RUN"
TASK = "TASK"
TASK_INSTANCE = "TASK_INSTANCE"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def _get_hitl_detail(
status.HTTP_409_CONFLICT,
]
),
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))],
dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.HITL_DETAIL))],
)
def update_hitl_detail(
dag_id: str,
Expand Down Expand Up @@ -203,7 +203,7 @@ def update_hitl_detail(
status.HTTP_409_CONFLICT,
]
),
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))],
dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.HITL_DETAIL))],
)
def update_mapped_ti_hitl_detail(
dag_id: str,
Expand All @@ -230,7 +230,7 @@ def update_mapped_ti_hitl_detail(
"/{dag_id}/{dag_run_id}/{task_id}",
status_code=status.HTTP_200_OK,
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))],
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.HITL_DETAIL))],
)
def get_hitl_detail(
dag_id: str,
Expand All @@ -252,7 +252,7 @@ def get_hitl_detail(
"/{dag_id}/{dag_run_id}/{task_id}/{map_index}",
status_code=status.HTTP_200_OK,
responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]),
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))],
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.HITL_DETAIL))],
)
def get_mapped_ti_hitl_detail(
dag_id: str,
Expand All @@ -274,7 +274,7 @@ def get_mapped_ti_hitl_detail(
@hitl_router.get(
"/",
status_code=status.HTTP_200_OK,
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))],
dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.HITL_DETAIL))],
)
def get_hitl_details(
limit: QueryLimit,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@
from airflow.providers.fab.www.security.permissions import RESOURCE_HITL_DETAIL

_MAP_MENU_ITEM_TO_FAB_RESOURCE_TYPE[MenuItem.REQUIRED_ACTIONS] = RESOURCE_HITL_DETAIL
_MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE[DagAccessEntity.HITL_DETAIL] = (RESOURCE_HITL_DETAIL,)


class FabAuthManager(BaseAuthManager[User]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,62 @@
RESOURCE_WEBSITE,
)

from tests_common.test_utils.version_compat import AIRFLOW_V_3_1_PLUS

if AIRFLOW_V_3_1_PLUS:
from airflow.providers.fab.www.security.permissions import RESOURCE_HITL_DETAIL

HITL_ENDPOINT_TESTS = [
# With global permissions on Dags, but no permission on HITL Detail
(
"GET",
DagAccessEntity.HITL_DETAIL,
None,
[(ACTION_CAN_READ, RESOURCE_DAG)],
False,
),
# With global permissions on Dags, but no permission on HITL Detail
(
"PUT",
DagAccessEntity.HITL_DETAIL,
None,
[(ACTION_CAN_READ, RESOURCE_DAG)],
False,
),
# With global permissions on Dags, with read permission on HITL Detail
(
"GET",
DagAccessEntity.HITL_DETAIL,
None,
[(ACTION_CAN_READ, RESOURCE_DAG), (ACTION_CAN_READ, RESOURCE_HITL_DETAIL)],
True,
),
# With global permissions on Dags, with read permission on HITL Detail, but wrong method
(
"PUT",
DagAccessEntity.HITL_DETAIL,
None,
[(ACTION_CAN_READ, RESOURCE_DAG), (ACTION_CAN_READ, RESOURCE_HITL_DETAIL)],
False,
),
# With global permissions on Dags, with write permission on HITL Detail, but wrong method
(
"GET",
DagAccessEntity.HITL_DETAIL,
None,
[(ACTION_CAN_READ, RESOURCE_DAG), (ACTION_CAN_EDIT, RESOURCE_HITL_DETAIL)],
False,
),
# With global permissions on Dags, with edit permission on HITL Detail
(
"PUT",
DagAccessEntity.HITL_DETAIL,
None,
[(ACTION_CAN_READ, RESOURCE_DAG), (ACTION_CAN_EDIT, RESOURCE_HITL_DETAIL)],
True,
),
]

if TYPE_CHECKING:
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod

Expand Down Expand Up @@ -468,6 +524,36 @@ def test_is_authorized_dag(
)
assert result == expected_result

@pytest.mark.skipif(
AIRFLOW_V_3_1_PLUS is not True, reason="HITL test will be skipped if Airflow version < 3.1.0"
)
@pytest.mark.parametrize(
"method, dag_access_entity, dag_details, user_permissions, expected_result",
HITL_ENDPOINT_TESTS if AIRFLOW_V_3_1_PLUS else [],
)
@mock.patch.object(FabAuthManager, "get_authorized_dag_ids")
def test_is_authorized_dag_hitl_detail(
self,
mock_get_authorized_dag_ids,
method,
dag_access_entity,
dag_details,
user_permissions,
expected_result,
auth_manager_with_appbuilder,
):
dag_permissions = [perm[1] for perm in user_permissions if perm[1].startswith("DAG:")]
dag_ids = {perm.replace("DAG:", "") for perm in dag_permissions}
mock_get_authorized_dag_ids.return_value = dag_ids

user = Mock()
user.perms = user_permissions
user.id = 1
result = auth_manager_with_appbuilder.is_authorized_dag(
method=method, access_entity=dag_access_entity, details=dag_details, user=user
)
assert result == expected_result

@pytest.mark.parametrize(
"access_view, user_permissions, expected_result",
[
Expand Down
Loading