diff --git a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py index d4416ab78ef5e..0d7a415453e16 100644 --- a/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py +++ b/providers/keycloak/src/airflow/providers/keycloak/auth_manager/keycloak_auth_manager.py @@ -25,11 +25,12 @@ from fastapi import FastAPI from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX -from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager, T +from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.keycloak.auth_manager.resources import KeycloakResource from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser +from airflow.utils.helpers import prune_dict if TYPE_CHECKING: from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod @@ -49,6 +50,8 @@ log = logging.getLogger(__name__) +RESOURCE_ID_ATTRIBUTE_NAME = "resource_id" + class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]): """ @@ -86,15 +89,10 @@ def is_authorized_configuration( ) -> bool: config_section = details.section if details else None return self._is_authorized( - method=method, resource_type=KeycloakResource.CONFIGURATION, user=user - ) or ( - config_section is not None - and self._is_authorized( - method=method, - resource_type=KeycloakResource.CONFIGURATION, - user=user, - resource_id=config_section, - ) + method=method, + resource_type=KeycloakResource.CONFIGURATION, + user=user, + resource_id=config_section, ) def is_authorized_connection( @@ -105,43 +103,42 @@ def is_authorized_connection( details: ConnectionDetails | None = None, ) -> bool: connection_id = details.conn_id if details else None - return self._is_authorized(method=method, resource_type=KeycloakResource.CONNECTION, user=user) or ( - connection_id is not None - and self._is_authorized( - method=method, resource_type=KeycloakResource.CONNECTION, user=user, resource_id=connection_id - ) + return self._is_authorized( + method=method, resource_type=KeycloakResource.CONNECTION, user=user, resource_id=connection_id ) def is_authorized_dag( self, *, method: ResourceMethod, - user: T, + user: KeycloakAuthManagerUser, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, ) -> bool: - return True + dag_id = details.id if details else None + access_entity_str = access_entity.value if access_entity else None + return self._is_authorized( + method=method, + resource_type=KeycloakResource.DAG, + user=user, + resource_id=dag_id, + attributes={"dag_entity": access_entity_str}, + ) def is_authorized_backfill( self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: BackfillDetails | None = None ) -> bool: backfill_id = str(details.id) if details else None - return self._is_authorized(method=method, resource_type=KeycloakResource.BACKFILL, user=user) or ( - backfill_id is not None - and self._is_authorized( - method=method, resource_type=KeycloakResource.BACKFILL, user=user, resource_id=backfill_id - ) + return self._is_authorized( + method=method, resource_type=KeycloakResource.BACKFILL, user=user, resource_id=backfill_id ) def is_authorized_asset( self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: AssetDetails | None = None ) -> bool: asset_id = details.id if details else None - return self._is_authorized(method=method, resource_type=KeycloakResource.ASSET, user=user) or ( - asset_id is not None - and self._is_authorized( - method=method, resource_type=KeycloakResource.ASSET, user=user, resource_id=asset_id - ) + return self._is_authorized( + method=method, resource_type=KeycloakResource.ASSET, user=user, resource_id=asset_id ) def is_authorized_asset_alias( @@ -152,42 +149,31 @@ def is_authorized_asset_alias( details: AssetAliasDetails | None = None, ) -> bool: asset_alias_id = details.id if details else None - return self._is_authorized(method=method, resource_type=KeycloakResource.ASSET_ALIAS, user=user) or ( - asset_alias_id is not None - and self._is_authorized( - method=method, - resource_type=KeycloakResource.ASSET_ALIAS, - user=user, - resource_id=asset_alias_id, - ) + return self._is_authorized( + method=method, + resource_type=KeycloakResource.ASSET_ALIAS, + user=user, + resource_id=asset_alias_id, ) def is_authorized_variable( self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: VariableDetails | None = None ) -> bool: variable_key = details.key if details else None - return self._is_authorized(method=method, resource_type=KeycloakResource.VARIABLE, user=user) or ( - variable_key is not None - and self._is_authorized( - method=method, resource_type=KeycloakResource.VARIABLE, user=user, resource_id=variable_key - ) + return self._is_authorized( + method=method, resource_type=KeycloakResource.VARIABLE, user=user, resource_id=variable_key ) def is_authorized_pool( self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: PoolDetails | None = None ) -> bool: pool_name = details.name if details else None - return self._is_authorized(method=method, resource_type=KeycloakResource.POOL, user=user) or ( - pool_name is not None - and self._is_authorized( - method=method, resource_type=KeycloakResource.POOL, user=user, resource_id=pool_name - ) + return self._is_authorized( + method=method, resource_type=KeycloakResource.POOL, user=user, resource_id=pool_name ) def is_authorized_view(self, *, access_view: AccessView, user: KeycloakAuthManagerUser) -> bool: return self._is_authorized( - method="GET", resource_type=KeycloakResource.VIEW, user=user - ) or self._is_authorized( method="GET", resource_type=KeycloakResource.VIEW, user=user, @@ -198,8 +184,6 @@ def is_authorized_custom_view( self, *, method: ResourceMethod | str, resource_name: str, user: KeycloakAuthManagerUser ) -> bool: return self._is_authorized( - method=method, resource_type=KeycloakResource.CUSTOM, user=user - ) or self._is_authorized( method=method, resource_type=KeycloakResource.CUSTOM, user=user, resource_id=resource_name ) @@ -230,19 +214,19 @@ def _is_authorized( resource_type: KeycloakResource, user: KeycloakAuthManagerUser, resource_id: str | None = None, + attributes: dict[str, str | None] | None = None, ) -> bool: client_id = conf.get("keycloak_auth_manager", "client_id") realm = conf.get("keycloak_auth_manager", "realm") server_url = conf.get("keycloak_auth_manager", "server_url") - permission = ( - f"{resource_type.value}:{resource_id}#{method}" - if resource_id - else f"{resource_type.value}#{method}" - ) + context_attributes = prune_dict(attributes or {}) + if resource_id: + context_attributes[RESOURCE_ID_ATTRIBUTE_NAME] = resource_id + resp = requests.post( self._get_token_url(server_url, realm), - data=self._get_payload(client_id, permission), + data=self._get_payload(client_id, f"{resource_type.value}#{method}", context_attributes), headers=self._get_headers(user.access_token), ) @@ -252,9 +236,6 @@ def _is_authorized( return False if resp.status_code == 400: error = json.loads(resp.text) - if error.get("error") == "invalid_resource": - log.debug(error["error_description"]) - return False raise AirflowException( f"Request not recognized by Keycloak. {error.get('error')}. {error.get('error_description')}" ) @@ -265,12 +246,16 @@ def _get_token_url(server_url, realm): return f"{server_url}/realms/{realm}/protocol/openid-connect/token" @staticmethod - def _get_payload(client_id, permission): - return { + def _get_payload(client_id: str, permission: str, attributes: dict[str, str] | None = None): + payload: dict[str, Any] = { "grant_type": "urn:ietf:params:oauth:grant-type:uma-ticket", "audience": client_id, "permission": permission, } + if attributes: + payload["context"] = {"attributes": attributes} + + return payload @staticmethod def _get_headers(access_token): diff --git a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py index e397044b86bcd..1cf2c3343d728 100644 --- a/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py +++ b/providers/keycloak/tests/unit/keycloak/auth_manager/test_keycloak_auth_manager.py @@ -17,7 +17,7 @@ from __future__ import annotations import json -from unittest.mock import Mock, call, patch +from unittest.mock import Mock, patch import pytest @@ -29,11 +29,16 @@ BackfillDetails, ConfigurationDetails, ConnectionDetails, + DagAccessEntity, + DagDetails, PoolDetails, VariableDetails, ) from airflow.exceptions import AirflowException -from airflow.providers.keycloak.auth_manager.keycloak_auth_manager import KeycloakAuthManager +from airflow.providers.keycloak.auth_manager.keycloak_auth_manager import ( + RESOURCE_ID_ATTRIBUTE_NAME, + KeycloakAuthManager, +) from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser from tests_common.test_utils.config import conf_vars @@ -91,14 +96,14 @@ def test_get_url_login(self, auth_manager): assert result == f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login" @pytest.mark.parametrize( - "function, method, details, permission_type, permission_resource", + "function, method, details, permission, attributes", [ [ "is_authorized_configuration", "GET", ConfigurationDetails(section="test"), "Configuration#GET", - "Configuration:test#GET", + {RESOURCE_ID_ATTRIBUTE_NAME: "test"}, ], ["is_authorized_configuration", "GET", None, "Configuration#GET", None], [ @@ -106,47 +111,63 @@ def test_get_url_login(self, auth_manager): "PUT", ConfigurationDetails(section="test"), "Configuration#PUT", - "Configuration:test#PUT", + {RESOURCE_ID_ATTRIBUTE_NAME: "test"}, ], [ "is_authorized_connection", "DELETE", ConnectionDetails(conn_id="test"), "Connection#DELETE", - "Connection:test#DELETE", + {RESOURCE_ID_ATTRIBUTE_NAME: "test"}, + ], + ["is_authorized_connection", "GET", None, "Connection#GET", {}], + [ + "is_authorized_backfill", + "POST", + BackfillDetails(id=1), + "Backfill#POST", + {RESOURCE_ID_ATTRIBUTE_NAME: "1"}, ], - ["is_authorized_connection", "GET", None, "Connection#GET", None], - ["is_authorized_backfill", "POST", BackfillDetails(id=1), "Backfill#POST", "Backfill:1#POST"], - ["is_authorized_backfill", "GET", None, "Backfill#GET", None], - ["is_authorized_asset", "GET", AssetDetails(id="test"), "Asset#GET", "Asset:test#GET"], - ["is_authorized_asset", "GET", None, "Asset#GET", None], + ["is_authorized_backfill", "GET", None, "Backfill#GET", {}], + [ + "is_authorized_asset", + "GET", + AssetDetails(id="test"), + "Asset#GET", + {RESOURCE_ID_ATTRIBUTE_NAME: "test"}, + ], + ["is_authorized_asset", "GET", None, "Asset#GET", {}], [ "is_authorized_asset_alias", "GET", AssetAliasDetails(id="test"), "AssetAlias#GET", - "AssetAlias:test#GET", + {RESOURCE_ID_ATTRIBUTE_NAME: "test"}, ], - ["is_authorized_asset_alias", "GET", None, "AssetAlias#GET", None], + ["is_authorized_asset_alias", "GET", None, "AssetAlias#GET", {}], [ "is_authorized_variable", "PUT", VariableDetails(key="test"), "Variable#PUT", - "Variable:test#PUT", + {RESOURCE_ID_ATTRIBUTE_NAME: "test"}, + ], + ["is_authorized_variable", "GET", None, "Variable#GET", {}], + [ + "is_authorized_pool", + "POST", + PoolDetails(name="test"), + "Pool#POST", + {RESOURCE_ID_ATTRIBUTE_NAME: "test"}, ], - ["is_authorized_variable", "GET", None, "Variable#GET", None], - ["is_authorized_pool", "POST", PoolDetails(name="test"), "Pool#POST", "Pool:test#POST"], - ["is_authorized_pool", "GET", None, "Pool#GET", None], + ["is_authorized_pool", "GET", None, "Pool#GET", {}], ], ) @pytest.mark.parametrize( - "status_code_type, status_code_resource, expected_one_call, expected_two_calls", + "status_code, expected", [ - [200, 200, True, True], - [200, 403, True, True], - [403, 200, False, True], - [403, 403, False, False], + [200, True], + [403, False], ], ) @patch("airflow.providers.keycloak.auth_manager.keycloak_auth_manager.requests") @@ -156,35 +177,21 @@ def test_is_authorized( function, method, details, - permission_type, - permission_resource, - status_code_type, - status_code_resource, - expected_one_call, - expected_two_calls, + permission, + attributes, + status_code, + expected, auth_manager, user, ): - expected_num_calls = 1 if status_code_type == 200 or not permission_resource else 2 - resp1 = Mock() - resp1.status_code = status_code_type - resp2 = Mock() - resp2.status_code = status_code_resource - - mock_requests.post.side_effect = [resp1, resp2] + mock_requests.post.return_value.status_code = status_code result = getattr(auth_manager, function)(method=method, user=user, details=details) token_url = auth_manager._get_token_url("server_url", "realm") - payload = auth_manager._get_payload("client_id", permission_type) + payload = auth_manager._get_payload("client_id", permission, attributes) headers = auth_manager._get_headers("access_token") - expected_calls = [call(token_url, data=payload, headers=headers)] - if expected_num_calls == 2: - expected_calls.append( - call(token_url, data={**payload, "permission": permission_resource}, headers=headers) - ) - mock_requests.post.assert_has_calls(expected_calls) - expected = expected_two_calls if expected_num_calls == 2 else expected_one_call + mock_requests.post.assert_called_once_with(token_url, data=payload, headers=headers) assert result == expected @pytest.mark.parametrize( @@ -192,6 +199,7 @@ def test_is_authorized( [ "is_authorized_configuration", "is_authorized_connection", + "is_authorized_dag", "is_authorized_backfill", "is_authorized_asset", "is_authorized_asset_alias", @@ -215,6 +223,7 @@ def test_is_authorized_failure(self, mock_requests, function, auth_manager, user [ "is_authorized_configuration", "is_authorized_connection", + "is_authorized_dag", "is_authorized_backfill", "is_authorized_asset", "is_authorized_asset_alias", @@ -223,123 +232,133 @@ def test_is_authorized_failure(self, mock_requests, function, auth_manager, user ], ) @patch("airflow.providers.keycloak.auth_manager.keycloak_auth_manager.requests") - def test_is_authorized_invalid_resource(self, mock_requests, function, auth_manager, user): + def test_is_authorized_invalid_request(self, mock_requests, function, auth_manager, user): resp = Mock() resp.status_code = 400 - resp.text = json.dumps( - {"error": "invalid_resource", "error_description": "Resource with id [Pool] does not exist."} - ) + resp.text = json.dumps({"error": "invalid_scope", "error_description": "Invalid scopes: GET"}) mock_requests.post.return_value = resp - result = getattr(auth_manager, function)(method="GET", user=user) - assert result is False + with pytest.raises(AirflowException) as e: + getattr(auth_manager, function)(method="GET", user=user) + + assert "Request not recognized by Keycloak. invalid_scope. Invalid scopes: GET" in str(e.value) @pytest.mark.parametrize( - "function", + "method, access_entity, details, permission, attributes", [ - "is_authorized_configuration", - "is_authorized_connection", - "is_authorized_backfill", - "is_authorized_asset", - "is_authorized_asset_alias", - "is_authorized_variable", - "is_authorized_pool", + [ + "GET", + None, + None, + "Dag#GET", + {}, + ], + [ + "GET", + DagAccessEntity.TASK_INSTANCE, + DagDetails(id="test"), + "Dag#GET", + {RESOURCE_ID_ATTRIBUTE_NAME: "test", "dag_entity": "TASK_INSTANCE"}, + ], + [ + "GET", + None, + DagDetails(id="test"), + "Dag#GET", + {RESOURCE_ID_ATTRIBUTE_NAME: "test"}, + ], + [ + "GET", + DagAccessEntity.TASK_INSTANCE, + None, + "Dag#GET", + {"dag_entity": "TASK_INSTANCE"}, + ], + ], + ) + @pytest.mark.parametrize( + "status_code, expected", + [ + [200, True], + [403, False], ], ) @patch("airflow.providers.keycloak.auth_manager.keycloak_auth_manager.requests") - def test_is_authorized_invalid_request(self, mock_requests, function, auth_manager, user): - resp = Mock() - resp.status_code = 400 - resp.text = json.dumps({"error": "invalid_scope", "error_description": "Invalid scopes: GET"}) - mock_requests.post.return_value = resp + def test_is_authorized_dag( + self, + mock_requests, + method, + access_entity, + details, + permission, + attributes, + status_code, + expected, + auth_manager, + user, + ): + mock_requests.post.return_value.status_code = status_code - with pytest.raises(AirflowException) as e: - getattr(auth_manager, function)(method="GET", user=user) + result = auth_manager.is_authorized_dag( + method=method, user=user, access_entity=access_entity, details=details + ) - assert "Request not recognized by Keycloak. invalid_scope. Invalid scopes: GET" in str(e.value) + token_url = auth_manager._get_token_url("server_url", "realm") + payload = auth_manager._get_payload("client_id", permission, attributes) + headers = auth_manager._get_headers("access_token") + mock_requests.post.assert_called_once_with(token_url, data=payload, headers=headers) + assert result == expected @pytest.mark.parametrize( - "status_code_type, status_code_resource, expected_one_call, expected_two_calls", + "status_code, expected", [ - [200, 200, True, True], - [200, 403, True, True], - [403, 200, False, True], - [403, 403, False, False], + [200, True], + [403, False], ], ) @patch("airflow.providers.keycloak.auth_manager.keycloak_auth_manager.requests") def test_is_authorized_view( self, mock_requests, - status_code_type, - status_code_resource, - expected_one_call, - expected_two_calls, + status_code, + expected, auth_manager, user, ): - expected_num_calls = 1 if status_code_type == 200 else 2 - resp1 = Mock() - resp1.status_code = status_code_type - resp2 = Mock() - resp2.status_code = status_code_resource - - mock_requests.post.side_effect = [resp1, resp2] + mock_requests.post.return_value.status_code = status_code result = auth_manager.is_authorized_view(access_view=AccessView.CLUSTER_ACTIVITY, user=user) token_url = auth_manager._get_token_url("server_url", "realm") - payload = auth_manager._get_payload("client_id", "View#GET") + payload = auth_manager._get_payload( + "client_id", "View#GET", {RESOURCE_ID_ATTRIBUTE_NAME: "CLUSTER_ACTIVITY"} + ) headers = auth_manager._get_headers("access_token") - expected_calls = [call(token_url, data=payload, headers=headers)] - if expected_num_calls == 2: - expected_calls.append( - call(token_url, data={**payload, "permission": "View:CLUSTER_ACTIVITY#GET"}, headers=headers) - ) - - mock_requests.post.assert_has_calls(expected_calls) - expected = expected_two_calls if expected_num_calls == 2 else expected_one_call + mock_requests.post.assert_called_once_with(token_url, data=payload, headers=headers) assert result == expected @pytest.mark.parametrize( - "status_code_type, status_code_resource, expected_one_call, expected_two_calls", + "status_code, expected", [ - [200, 200, True, True], - [200, 403, True, True], - [403, 200, False, True], - [403, 403, False, False], + [200, True], + [403, False], ], ) @patch("airflow.providers.keycloak.auth_manager.keycloak_auth_manager.requests") def test_is_authorized_custom_view( self, mock_requests, - status_code_type, - status_code_resource, - expected_one_call, - expected_two_calls, + status_code, + expected, auth_manager, user, ): - expected_num_calls = 1 if status_code_type == 200 else 2 - resp1 = Mock() - resp1.status_code = status_code_type - resp2 = Mock() - resp2.status_code = status_code_resource - - mock_requests.post.side_effect = [resp1, resp2] + mock_requests.post.return_value.status_code = status_code result = auth_manager.is_authorized_custom_view(method="GET", resource_name="test", user=user) token_url = auth_manager._get_token_url("server_url", "realm") - payload = auth_manager._get_payload("client_id", "Custom#GET") + payload = auth_manager._get_payload("client_id", "Custom#GET", {RESOURCE_ID_ATTRIBUTE_NAME: "test"}) headers = auth_manager._get_headers("access_token") - expected_calls = [call(token_url, data=payload, headers=headers)] - if expected_num_calls == 2: - expected_calls.append( - call(token_url, data={**payload, "permission": "Custom:test#GET"}, headers=headers) - ) - - mock_requests.post.assert_has_calls(expected_calls) - expected = expected_two_calls if expected_num_calls == 2 else expected_one_call + mock_requests.post.assert_called_once_with(token_url, data=payload, headers=headers) assert result == expected