Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions airflow-core/docs/core-concepts/auth-manager/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,7 @@ Optional methods recommended to override for optimization

The following methods aren't required to override to have a functional Airflow auth manager. However, it is recommended to override these to make your auth manager faster (and potentially less costly):

* ``batch_is_authorized_connection``: Batch version of ``is_authorized_connection``. If not overridden, it will call ``is_authorized_connection`` for every single item.
* ``batch_is_authorized_dag``: Batch version of ``is_authorized_dag``. If not overridden, it will call ``is_authorized_dag`` for every single item.
* ``batch_is_authorized_pool``: Batch version of ``is_authorized_pool``. If not overridden, it will call ``is_authorized_pool`` for every single item.
* ``batch_is_authorized_variable``: Batch version of ``is_authorized_variable``. If not overridden, it will call ``is_authorized_variable`` for every single item.
* ``get_authorized_dag_ids``: Return the list of DAG IDs the user has access to. If not overridden, it will call ``is_authorized_dag`` for every single DAG available in the environment.

CLI
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@
from sqlalchemy.orm import Session

from airflow.api_fastapi.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
IsAuthorizedDagRequest,
IsAuthorizedPoolRequest,
IsAuthorizedVariableRequest,
)
from airflow.api_fastapi.auth.managers.models.resource_details import (
AccessView,
Expand Down Expand Up @@ -307,27 +304,6 @@ def filter_authorized_menu_items(self, menu_items: list[MenuItem], *, user: T) -
:param user: the user
"""

def batch_is_authorized_connection(
self,
requests: Sequence[IsAuthorizedConnectionRequest],
*,
user: T,
) -> bool:
"""
Batch version of ``is_authorized_connection``.

By default, calls individually the ``is_authorized_connection`` API on each item in the list of
requests, which can lead to some poor performance. It is recommended to override this method in the auth
manager implementation to provide a more efficient implementation.

:param requests: a list of requests containing the parameters for ``is_authorized_connection``
:param user: the user to performing the action
"""
return all(
self.is_authorized_connection(method=request["method"], details=request.get("details"), user=user)
for request in requests
)

def batch_is_authorized_dag(
self,
requests: Sequence[IsAuthorizedDagRequest],
Expand All @@ -354,48 +330,6 @@ def batch_is_authorized_dag(
for request in requests
)

def batch_is_authorized_pool(
self,
requests: Sequence[IsAuthorizedPoolRequest],
*,
user: T,
) -> bool:
"""
Batch version of ``is_authorized_pool``.

By default, calls individually the ``is_authorized_pool`` API on each item in the list of
requests. Can lead to some poor performance. It is recommended to override this method in the auth
manager implementation to provide a more efficient implementation.

:param requests: a list of requests containing the parameters for ``is_authorized_pool``
:param user: the user to performing the action
"""
return all(
self.is_authorized_pool(method=request["method"], details=request.get("details"), user=user)
for request in requests
)

def batch_is_authorized_variable(
self,
requests: Sequence[IsAuthorizedVariableRequest],
*,
user: T,
) -> bool:
"""
Batch version of ``is_authorized_variable``.

By default, calls individually the ``is_authorized_variable`` API on each item in the list of
requests. Can lead to some poor performance. It is recommended to override this method in the auth
manager implementation to provide a more efficient implementation.

:param requests: a list of requests containing the parameters for ``is_authorized_variable``
:param user: the user to performing the action
"""
return all(
self.is_authorized_variable(method=request["method"], details=request.get("details"), user=user)
for request in requests
)

@provide_session
def get_authorized_dag_ids(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,14 @@
if TYPE_CHECKING:
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
from airflow.api_fastapi.auth.managers.models.resource_details import (
ConnectionDetails,
DagAccessEntity,
DagDetails,
PoolDetails,
VariableDetails,
)


class IsAuthorizedConnectionRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized_connection`` API in the auth manager."""

method: ResourceMethod
details: ConnectionDetails | None


class IsAuthorizedDagRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized_dag`` API in the auth manager."""

method: ResourceMethod
access_entity: DagAccessEntity | None
details: DagDetails | None


class IsAuthorizedPoolRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized_pool`` API in the auth manager."""

method: ResourceMethod
details: PoolDetails | None


class IsAuthorizedVariableRequest(TypedDict, total=False):
"""Represent the parameters of ``is_authorized_variable`` API in the auth manager."""

method: ResourceMethod
details: VariableDetails | None
Original file line number Diff line number Diff line change
Expand Up @@ -261,70 +261,6 @@ def test_batch_is_authorized_dag(self, mock_is_authorized_dag, auth_manager, ret
)
assert result == expected

@pytest.mark.parametrize(
"return_values, expected",
[
([False, False], False),
([True, False], False),
([True, True], True),
],
)
@patch.object(EmptyAuthManager, "is_authorized_connection")
def test_batch_is_authorized_connection(
self, mock_is_authorized_connection, auth_manager, return_values, expected
):
mock_is_authorized_connection.side_effect = return_values
result = auth_manager.batch_is_authorized_connection(
[
{"method": "GET", "details": ConnectionDetails(conn_id="conn1")},
{"method": "GET", "details": ConnectionDetails(conn_id="conn2")},
],
user=Mock(),
)
assert result == expected

@pytest.mark.parametrize(
"return_values, expected",
[
([False, False], False),
([True, False], False),
([True, True], True),
],
)
@patch.object(EmptyAuthManager, "is_authorized_pool")
def test_batch_is_authorized_pool(self, mock_is_authorized_pool, auth_manager, return_values, expected):
mock_is_authorized_pool.side_effect = return_values
result = auth_manager.batch_is_authorized_pool(
[
{"method": "GET", "details": PoolDetails(name="pool1")},
{"method": "GET", "details": PoolDetails(name="pool2")},
],
user=Mock(),
)
assert result == expected

@pytest.mark.parametrize(
"return_values, expected",
[
([False, False], False),
([True, False], False),
([True, True], True),
],
)
@patch.object(EmptyAuthManager, "is_authorized_variable")
def test_batch_is_authorized_variable(
self, mock_is_authorized_variable, auth_manager, return_values, expected
):
mock_is_authorized_variable.side_effect = return_values
result = auth_manager.batch_is_authorized_variable(
[
{"method": "GET", "details": VariableDetails(key="var1")},
{"method": "GET", "details": VariableDetails(key="var2")},
],
user=Mock(),
)
assert result == expected

@pytest.mark.parametrize(
"access_per_dag, dag_ids, expected",
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@
if TYPE_CHECKING:
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
from airflow.api_fastapi.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
IsAuthorizedDagRequest,
IsAuthorizedPoolRequest,
IsAuthorizedVariableRequest,
)
from airflow.api_fastapi.auth.managers.models.resource_details import (
AccessView,
Expand Down Expand Up @@ -247,24 +244,6 @@ def _has_access_to_menu_item(request: IsAuthorizedRequest):

return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])]

def batch_is_authorized_connection(
self,
requests: Sequence[IsAuthorizedConnectionRequest],
*,
user: AwsAuthManagerUser,
) -> bool:
facade_requests: Sequence[IsAuthorizedRequest] = [
{
"method": request["method"],
"entity_type": AvpEntities.CONNECTION,
"entity_id": cast("ConnectionDetails", request["details"]).conn_id
if request.get("details")
else None,
}
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

def batch_is_authorized_dag(
self,
requests: Sequence[IsAuthorizedDagRequest],
Expand All @@ -288,40 +267,6 @@ def batch_is_authorized_dag(
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

def batch_is_authorized_pool(
self,
requests: Sequence[IsAuthorizedPoolRequest],
*,
user: AwsAuthManagerUser,
) -> bool:
facade_requests: Sequence[IsAuthorizedRequest] = [
{
"method": request["method"],
"entity_type": AvpEntities.POOL,
"entity_id": cast("PoolDetails", request["details"]).name if request.get("details") else None,
}
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

def batch_is_authorized_variable(
self,
requests: Sequence[IsAuthorizedVariableRequest],
*,
user: AwsAuthManagerUser,
) -> bool:
facade_requests: Sequence[IsAuthorizedRequest] = [
{
"method": request["method"],
"entity_type": AvpEntities.VARIABLE,
"entity_id": cast("VariableDetails", request["details"]).key
if request.get("details")
else None,
}
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

def filter_authorized_dag_ids(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,37 +368,6 @@ def test_is_authorized_view(self, mock_avp_facade, access_view, user, expected_u
)
assert result

@patch.object(AwsAuthManager, "avp_facade")
def test_batch_is_authorized_connection(
self,
mock_avp_facade,
auth_manager,
):
batch_is_authorized = Mock(return_value=True)
mock_avp_facade.batch_is_authorized = batch_is_authorized

result = auth_manager.batch_is_authorized_connection(
requests=[{"method": "GET"}, {"method": "GET", "details": ConnectionDetails(conn_id="conn_id")}],
user=mock,
)

batch_is_authorized.assert_called_once_with(
requests=[
{
"method": "GET",
"entity_type": AvpEntities.CONNECTION,
"entity_id": None,
},
{
"method": "GET",
"entity_type": AvpEntities.CONNECTION,
"entity_id": "conn_id",
},
],
user=ANY,
)
assert result

def test_filter_authorized_menu_items(self, auth_manager):
batch_is_authorized_output = [
{
Expand Down Expand Up @@ -541,68 +510,6 @@ def test_batch_is_authorized_dag(
)
assert result

@patch.object(AwsAuthManager, "avp_facade")
def test_batch_is_authorized_pool(
self,
mock_avp_facade,
auth_manager,
):
batch_is_authorized = Mock(return_value=True)
mock_avp_facade.batch_is_authorized = batch_is_authorized

result = auth_manager.batch_is_authorized_pool(
requests=[{"method": "GET"}, {"method": "GET", "details": PoolDetails(name="pool1")}],
user=mock,
)

batch_is_authorized.assert_called_once_with(
requests=[
{
"method": "GET",
"entity_type": AvpEntities.POOL,
"entity_id": None,
},
{
"method": "GET",
"entity_type": AvpEntities.POOL,
"entity_id": "pool1",
},
],
user=ANY,
)
assert result

@patch.object(AwsAuthManager, "avp_facade")
def test_batch_is_authorized_variable(
self,
mock_avp_facade,
auth_manager,
):
batch_is_authorized = Mock(return_value=True)
mock_avp_facade.batch_is_authorized = batch_is_authorized

result = auth_manager.batch_is_authorized_variable(
requests=[{"method": "GET"}, {"method": "GET", "details": VariableDetails(key="var1")}],
user=mock,
)

batch_is_authorized.assert_called_once_with(
requests=[
{
"method": "GET",
"entity_type": AvpEntities.VARIABLE,
"entity_id": None,
},
{
"method": "GET",
"entity_type": AvpEntities.VARIABLE,
"entity_id": "var1",
},
],
user=ANY,
)
assert result

@pytest.mark.parametrize(
"method, user, expected_result",
[
Expand Down
Loading