Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 9 additions & 23 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from collections.abc import Container, Sequence
from collections.abc import Sequence

from fastapi import FastAPI
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -331,7 +331,7 @@ def get_permitted_dag_ids(
self,
*,
user: T,
methods: Container[ResourceMethod] | None = None,
method: ResourceMethod = "GET",
session: Session = NEW_SESSION,
) -> set[str]:
"""
Expand All @@ -342,45 +342,31 @@ def get_permitted_dag_ids(
implementation to provide a more efficient implementation.

:param user: the user
:param methods: whether filter readable or writable
:param method: the method to filter on
:param session: the session
"""
dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}
return self.filter_permitted_dag_ids(dag_ids=dag_ids, methods=methods, user=user)
return self.filter_permitted_dag_ids(dag_ids=dag_ids, method=method, user=user)

def filter_permitted_dag_ids(
self,
*,
dag_ids: set[str],
user: T,
methods: Container[ResourceMethod] | None = None,
method: ResourceMethod = "GET",
) -> set[str]:
"""
Filter readable or writable DAGs for user.

:param dag_ids: the list of DAG ids
:param user: the user
:param methods: whether filter readable or writable
:param method: the method to filter on
"""
if not methods:
methods = ["PUT", "GET"]

if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or (
"PUT" in methods and self.is_authorized_dag(method="PUT", user=user)
):
# If user is authorized to read/edit all DAGs, return all DAGs
return dag_ids
def _is_permitted_dag_id(method: ResourceMethod, dag_id: str):
return self.is_authorized_dag(method=method, details=DagDetails(id=dag_id), user=user)

def _is_permitted_dag_id(method: ResourceMethod, methods: Container[ResourceMethod], dag_id: str):
return method in methods and self.is_authorized_dag(
method=method, details=DagDetails(id=dag_id), user=user
)

return {
dag_id
for dag_id in dag_ids
if _is_permitted_dag_id("GET", methods, dag_id) or _is_permitted_dag_id("PUT", methods, dag_id)
}
return {dag_id for dag_id in dag_ids if _is_permitted_dag_id(method, dag_id)}

@staticmethod
def get_cli_commands() -> list[CLICommand]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import argparse
from collections import defaultdict
from collections.abc import Container, Sequence
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, cast

Expand Down Expand Up @@ -283,23 +283,18 @@ def filter_permitted_dag_ids(
*,
dag_ids: set[str],
user: AwsAuthManagerUser,
methods: Container[ResourceMethod] | None = None,
method: ResourceMethod = "GET",
):
if not methods:
methods = ["PUT", "GET"]

requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
requests_list: list[IsAuthorizedRequest] = []
for dag_id in dag_ids:
for method in ["GET", "PUT"]:
if method in methods:
request: IsAuthorizedRequest = {
"method": cast("ResourceMethod", method),
"entity_type": AvpEntities.DAG,
"entity_id": dag_id,
}
requests[dag_id][cast("ResourceMethod", method)] = request
requests_list.append(request)
request: IsAuthorizedRequest = {
"method": method,
"entity_type": AvpEntities.DAG,
"entity_id": dag_id,
}
requests[dag_id][method] = request
requests_list.append(request)

batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
requests=requests_list, user=user
Expand All @@ -311,16 +306,7 @@ def _has_access_to_dag(request: IsAuthorizedRequest):
)
return result["decision"] == "ALLOW"

return {
dag_id
for dag_id in dag_ids
if (
"GET" in methods
and _has_access_to_dag(requests[dag_id]["GET"])
or "PUT" in methods
and _has_access_to_dag(requests[dag_id]["PUT"])
)
}
return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}

def get_url_login(self, **kwargs) -> str:
return f"{self.apiserver_endpoint}/auth/login"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,42 +445,78 @@ def test_batch_is_authorized_variable(
assert result

@pytest.mark.parametrize(
"methods, user",
"method, user, expected_result",
[
(None, AwsAuthManagerUser(user_id="test_user_id", groups=[])),
(["PUT", "GET"], AwsAuthManagerUser(user_id="test_user_id", groups=[])),
("GET", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), {"dag_1"}),
("PUT", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), set()),
("GET", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), set()),
("PUT", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), {"dag_2"}),
],
)
def test_filter_permitted_dag_ids(self, methods, user, auth_manager, test_user):
def test_filter_permitted_dag_ids(self, method, user, auth_manager, test_user, expected_result):
dag_ids = {"dag_1", "dag_2"}
# test_user_id1 has GET permissions on dag_1
# test_user_id2 has PUT permissions on dag_2
batch_is_authorized_output = [
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id"},
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
},
"decision": "ALLOW",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id"},
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id"},
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id"},
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
},
Expand All @@ -493,12 +529,12 @@ def test_filter_permitted_dag_ids(self, methods, user, auth_manager, test_user):

result = auth_manager.filter_permitted_dag_ids(
dag_ids=dag_ids,
methods=methods,
method=method,
user=user,
)

auth_manager.avp_facade.get_batch_is_authorized_results.assert_called()
assert result == {"dag_2"}
assert result == expected_result

def test_get_url_login(self, auth_manager):
result = auth_manager.get_url_login()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import annotations

import argparse
from collections.abc import Container
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -58,6 +57,7 @@
USERS_COMMANDS,
)
from airflow.providers.fab.auth_manager.models import Permission, Role, User
from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser
from airflow.providers.fab.www.app import create_app
from airflow.providers.fab.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED
from airflow.providers.fab.www.extensions.init_views import _CustomErrorRequestBodyValidator, _LazyResolver
Expand Down Expand Up @@ -355,30 +355,24 @@ def get_permitted_dag_ids(
self,
*,
user: User,
methods: Container[ResourceMethod] | None = None,
method: ResourceMethod = "GET",
session: Session = NEW_SESSION,
) -> set[str]:
if not methods:
methods = ["PUT", "GET"]

if not self.is_logged_in():
roles = user.roles
else:
if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or (
"PUT" in methods and self.is_authorized_dag(method="PUT", user=user)
):
# If user is authorized to read/edit all DAGs, return all DAGs
return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}
user_query = session.scalar(
select(User)
.options(
joinedload(User.roles)
.subqueryload(Role.permissions)
.options(joinedload(Permission.action), joinedload(Permission.resource))
)
.where(User.id == user.id)
if self._is_authorized(method=method, resource_type=RESOURCE_DAG, user=user):
# If user is authorized to access all DAGs, return all DAGs
return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))}
if isinstance(user, AnonymousUser):
return set()
user_query = session.scalar(
select(User)
.options(
joinedload(User.roles)
.subqueryload(Role.permissions)
.options(joinedload(Permission.action), joinedload(Permission.resource))
)
roles = user_query.roles
.where(User.id == user.id)
)
roles = user_query.roles

map_fab_action_name_to_method_name = get_method_from_fab_action_map()
resources = set()
Expand All @@ -387,7 +381,7 @@ def get_permitted_dag_ids(
action = permission.action.name
if (
action in map_fab_action_name_to_method_name
and map_fab_action_name_to_method_name[action] in methods
and map_fab_action_name_to_method_name[action] == method
):
resource = permission.resource.name
if resource == permissions.RESOURCE_DAG:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,12 +973,12 @@ def create_db(self):
@staticmethod
def get_readable_dag_ids(user=None) -> set[str]:
"""Get the DAG IDs readable by authenticated user."""
return get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=user)
return get_auth_manager().get_permitted_dag_ids(user=user)

@staticmethod
def get_editable_dag_ids(user=None) -> set[str]:
"""Get the DAG IDs editable by authenticated user."""
return get_auth_manager().get_permitted_dag_ids(methods=["PUT"], user=user)
return get_auth_manager().get_permitted_dag_ids(method="PUT", user=user)

def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool:
"""Check if user has read or write access to some dags."""
Expand Down
Loading