Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@
if TYPE_CHECKING:
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
from airflow.api_fastapi.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
IsAuthorizedDagRequest,
IsAuthorizedPoolRequest,
IsAuthorizedVariableRequest,
)
from airflow.api_fastapi.auth.managers.models.resource_details import (
AccessView,
Expand Down Expand Up @@ -244,25 +247,93 @@ def _has_access_to_menu_item(request: IsAuthorizedRequest):

return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])]

def batch_is_authorized_connection(
self,
requests: Sequence[IsAuthorizedConnectionRequest],
*,
user: AwsAuthManagerUser,
) -> bool:
facade_requests: Sequence[IsAuthorizedRequest] = [
cast(
"IsAuthorizedRequest",
{
"method": request["method"],
"entity_type": AvpEntities.CONNECTION,
"entity_id": cast("ConnectionDetails", request["details"]).conn_id
if request.get("details")
else None,
},
)
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

def batch_is_authorized_dag(
self,
requests: Sequence[IsAuthorizedDagRequest],
*,
user: AwsAuthManagerUser,
) -> bool:
facade_requests: Sequence[IsAuthorizedRequest] = [
{
"method": request["method"],
"entity_type": AvpEntities.DAG,
"entity_id": cast("DagDetails", request["details"]).id if request.get("details") else None,
"context": {
"dag_entity": {
"string": cast("DagAccessEntity", request["access_entity"]).value,
},
}
if request.get("access_entity")
else None,
}
cast(
"IsAuthorizedRequest",
{
"method": request["method"],
"entity_type": AvpEntities.DAG,
"entity_id": cast("DagDetails", request["details"]).id
if request.get("details")
else None,
"context": {
"dag_entity": {
"string": cast("DagAccessEntity", request["access_entity"]).value,
},
}
if request.get("access_entity")
else None,
},
)
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

def batch_is_authorized_pool(
self,
requests: Sequence[IsAuthorizedPoolRequest],
*,
user: AwsAuthManagerUser,
) -> bool:
facade_requests: Sequence[IsAuthorizedRequest] = [
cast(
"IsAuthorizedRequest",
{
"method": request["method"],
"entity_type": AvpEntities.POOL,
"entity_id": cast("PoolDetails", request["details"]).name
if request.get("details")
else None,
},
)
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

def batch_is_authorized_variable(
self,
requests: Sequence[IsAuthorizedVariableRequest],
*,
user: AwsAuthManagerUser,
) -> bool:
facade_requests: Sequence[IsAuthorizedRequest] = [
cast(
"IsAuthorizedRequest",
{
"method": request["method"],
"entity_type": AvpEntities.VARIABLE,
"entity_id": cast("VariableDetails", request["details"]).key
if request.get("details")
else None,
},
)
for request in requests
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,40 @@ def test_filter_authorized_menu_items(self, auth_manager):
)
assert result == [MenuItem.VARIABLES, MenuItem.DAGS]

@patch.object(AwsAuthManager, "avp_facade")
def test_batch_is_authorized_connection(
self,
mock_avp_facade,
auth_manager,
):
batch_is_authorized = Mock(return_value=True)
mock_avp_facade.batch_is_authorized = batch_is_authorized

result = auth_manager.batch_is_authorized_connection(
requests=[
{"method": "GET"},
{"method": "PUT", "details": ConnectionDetails(conn_id="test")},
],
user=mock,
)

batch_is_authorized.assert_called_once_with(
requests=[
{
"method": "GET",
"entity_type": AvpEntities.CONNECTION,
"entity_id": None,
},
{
"method": "PUT",
"entity_type": AvpEntities.CONNECTION,
"entity_id": "test",
},
],
user=ANY,
)
assert result

@patch.object(AwsAuthManager, "avp_facade")
def test_batch_is_authorized_dag(
self,
Expand Down Expand Up @@ -510,6 +544,74 @@ def test_batch_is_authorized_dag(
)
assert result

@patch.object(AwsAuthManager, "avp_facade")
def test_batch_is_authorized_pool(
self,
mock_avp_facade,
auth_manager,
):
batch_is_authorized = Mock(return_value=True)
mock_avp_facade.batch_is_authorized = batch_is_authorized

result = auth_manager.batch_is_authorized_pool(
requests=[
{"method": "GET"},
{"method": "PUT", "details": PoolDetails(name="test")},
],
user=mock,
)

batch_is_authorized.assert_called_once_with(
requests=[
{
"method": "GET",
"entity_type": AvpEntities.POOL,
"entity_id": None,
},
{
"method": "PUT",
"entity_type": AvpEntities.POOL,
"entity_id": "test",
},
],
user=ANY,
)
assert result

@patch.object(AwsAuthManager, "avp_facade")
def test_batch_is_authorized_variable(
self,
mock_avp_facade,
auth_manager,
):
batch_is_authorized = Mock(return_value=True)
mock_avp_facade.batch_is_authorized = batch_is_authorized

result = auth_manager.batch_is_authorized_variable(
requests=[
{"method": "GET"},
{"method": "PUT", "details": VariableDetails(key="test")},
],
user=mock,
)

batch_is_authorized.assert_called_once_with(
requests=[
{
"method": "GET",
"entity_type": AvpEntities.VARIABLE,
"entity_id": None,
},
{
"method": "PUT",
"entity_type": AvpEntities.VARIABLE,
"entity_id": "test",
},
],
user=ANY,
)
assert result

@pytest.mark.parametrize(
"method, user, expected_result",
[
Expand Down