Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,37 @@ def batch_is_authorized_variable(
]
return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)

def filter_authorized_connections(
self,
*,
conn_ids: set[str],
user: AwsAuthManagerUser,
method: ResourceMethod = "GET",
team_name: str | None = None,
) -> set[str]:
requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
requests_list: list[IsAuthorizedRequest] = []
for conn_id in conn_ids:
request: IsAuthorizedRequest = {
"method": method,
"entity_type": AvpEntities.CONNECTION,
"entity_id": conn_id,
}
requests[conn_id][method] = request
requests_list.append(request)

batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
requests=requests_list, user=user
)

return {
conn_id
for conn_id in conn_ids
if self._is_authorized_from_batch_response(
batch_is_authorized_results, requests[conn_id][method], user
)
}

def filter_authorized_dag_ids(
self,
*,
Expand All @@ -361,13 +392,75 @@ def filter_authorized_dag_ids(
requests=requests_list, user=user
)

def _has_access_to_dag(request: IsAuthorizedRequest):
result = self.avp_facade.get_batch_is_authorized_single_result(
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
return {
dag_id
for dag_id in dag_ids
if self._is_authorized_from_batch_response(
batch_is_authorized_results, requests[dag_id][method], user
)
return result["decision"] == "ALLOW"
}

return {dag_id for dag_id in dag_ids if _has_access_to_dag(requests[dag_id][method])}
def filter_authorized_pools(
self,
*,
pool_names: set[str],
user: AwsAuthManagerUser,
method: ResourceMethod = "GET",
team_name: str | None = None,
) -> set[str]:
requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
requests_list: list[IsAuthorizedRequest] = []
for pool_name in pool_names:
request: IsAuthorizedRequest = {
"method": method,
"entity_type": AvpEntities.POOL,
"entity_id": pool_name,
}
requests[pool_name][method] = request
requests_list.append(request)

batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
requests=requests_list, user=user
)

return {
pool_name
for pool_name in pool_names
if self._is_authorized_from_batch_response(
batch_is_authorized_results, requests[pool_name][method], user
)
}

def filter_authorized_variables(
self,
*,
variable_keys: set[str],
user: AwsAuthManagerUser,
method: ResourceMethod = "GET",
team_name: str | None = None,
) -> set[str]:
requests: dict[str, dict[ResourceMethod, IsAuthorizedRequest]] = defaultdict(dict)
requests_list: list[IsAuthorizedRequest] = []
for variable_key in variable_keys:
request: IsAuthorizedRequest = {
"method": method,
"entity_type": AvpEntities.VARIABLE,
"entity_id": variable_key,
}
requests[variable_key][method] = request
requests_list.append(request)

batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
requests=requests_list, user=user
)

return {
variable_key
for variable_key in variable_keys
if self._is_authorized_from_batch_response(
batch_is_authorized_results, requests[variable_key][method], user
)
}

def get_url_login(self, **kwargs) -> str:
return urljoin(self.apiserver_endpoint, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
Expand Down Expand Up @@ -406,6 +499,14 @@ def _get_menu_item_request(menu_item_text: str) -> IsAuthorizedRequest:
"entity_id": menu_item_text,
}

def _is_authorized_from_batch_response(
self, batch_is_authorized_results: list[dict], request: IsAuthorizedRequest, user: AwsAuthManagerUser
):
result = self.avp_facade.get_batch_is_authorized_single_result(
batch_is_authorized_results=batch_is_authorized_results, request=request, user=user
)
return result["decision"] == "ALLOW"

def _check_avp_schema_version(self):
if not self.avp_facade.is_policy_store_schema_up_to_date():
self.log.warning(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,81 +612,100 @@ def test_batch_is_authorized_variable(
)
assert result

@pytest.mark.parametrize(
"get_authorized_method, avp_entity, entities_parameter",
[
("filter_authorized_connections", AvpEntities.CONNECTION.value, "conn_ids"),
("filter_authorized_dag_ids", AvpEntities.DAG.value, "dag_ids"),
("filter_authorized_pools", AvpEntities.POOL.value, "pool_names"),
("filter_authorized_variables", AvpEntities.VARIABLE.value, "variable_keys"),
],
)
@pytest.mark.parametrize(
"method, user, expected_result",
[
("GET", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), {"dag_1"}),
("GET", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), {"entity_1"}),
("PUT", AwsAuthManagerUser(user_id="test_user_id1", groups=[]), set()),
("GET", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), set()),
("PUT", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), {"dag_2"}),
("PUT", AwsAuthManagerUser(user_id="test_user_id2", groups=[]), {"entity_2"}),
],
)
def test_filter_authorized_dag_ids(self, method, user, auth_manager, test_user, expected_result):
dag_ids = {"dag_1", "dag_2"}
# test_user_id1 has GET permissions on dag_1
# test_user_id2 has PUT permissions on dag_2
def test_filter_authorized(
self,
get_authorized_method,
avp_entity,
entities_parameter,
method,
user,
auth_manager,
test_user,
expected_result,
):
entity_ids = {"entity_1", "entity_2"}
# test_user_id1 has GET permissions on entity_1
# test_user_id2 has PUT permissions on entity_2
batch_is_authorized_output = [
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
"action": {"actionType": "Airflow::Action", "actionId": f"{avp_entity}.GET"},
"resource": {"entityType": f"Airflow::{avp_entity}", "entityId": "entity_1"},
},
"decision": "ALLOW",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
"action": {"actionType": "Airflow::Action", "actionId": f"{avp_entity}.PUT"},
"resource": {"entityType": f"Airflow::{avp_entity}", "entityId": "entity_1"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
"action": {"actionType": "Airflow::Action", "actionId": f"{avp_entity}.GET"},
"resource": {"entityType": f"Airflow::{avp_entity}", "entityId": "entity_2"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id1"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
"action": {"actionType": "Airflow::Action", "actionId": f"{avp_entity}.PUT"},
"resource": {"entityType": f"Airflow::{avp_entity}", "entityId": "entity_2"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
"action": {"actionType": "Airflow::Action", "actionId": f"{avp_entity}.GET"},
"resource": {"entityType": f"Airflow::{avp_entity}", "entityId": "entity_1"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_1"},
"action": {"actionType": "Airflow::Action", "actionId": f"{avp_entity}.PUT"},
"resource": {"entityType": f"Airflow::{avp_entity}", "entityId": "entity_1"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.GET"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
"action": {"actionType": "Airflow::Action", "actionId": f"{avp_entity}.GET"},
"resource": {"entityType": f"Airflow::{avp_entity}", "entityId": "entity_2"},
},
"decision": "DENY",
},
{
"request": {
"principal": {"entityType": "Airflow::User", "entityId": "test_user_id2"},
"action": {"actionType": "Airflow::Action", "actionId": "Dag.PUT"},
"resource": {"entityType": "Airflow::Dag", "entityId": "dag_2"},
"action": {"actionType": "Airflow::Action", "actionId": f"{avp_entity}.PUT"},
"resource": {"entityType": f"Airflow::{avp_entity}", "entityId": "entity_2"},
},
"decision": "ALLOW",
},
Expand All @@ -695,11 +714,12 @@ def test_filter_authorized_dag_ids(self, method, user, auth_manager, test_user,
return_value=batch_is_authorized_output
)

result = auth_manager.filter_authorized_dag_ids(
dag_ids=dag_ids,
method=method,
user=user,
)
params = {
entities_parameter: entity_ids,
"method": method,
"user": user,
}
result = getattr(auth_manager, get_authorized_method)(**params)

auth_manager.avp_facade.get_batch_is_authorized_results.assert_called()
assert result == expected_result
Expand Down