From 288348ca872523849d9807f66869adf9767a7446 Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Tue, 19 Nov 2024 14:57:14 -0500 Subject: [PATCH] Set up JWT token authentication in Fast APIs (#42634) --- airflow/api_fastapi/app.py | 45 +++++++++ airflow/api_fastapi/core_api/security.py | 77 +++++++++++++++ airflow/auth/managers/base_auth_manager.py | 36 ++++--- .../managers/simple/simple_auth_manager.py | 69 +++++++++---- airflow/auth/managers/simple/views/auth.py | 15 ++- airflow/config_templates/config.yml | 20 ++++ airflow/configuration.py | 1 + docs/spelling_wordlist.txt | 1 + .../aws/auth_manager/aws_auth_manager.py | 3 +- .../fab/auth_manager/fab_auth_manager.py | 32 ++++-- .../fab/auth_manager/test_fab_auth_manager.py | 22 ++++- tests/api_fastapi/core_api/test_security.py | 99 +++++++++++++++++++ .../simple/test_simple_auth_manager.py | 15 ++- tests/auth/managers/simple/views/test_auth.py | 9 +- tests/auth/managers/test_base_auth_manager.py | 12 ++- tests/core/test_configuration.py | 1 + 16 files changed, 402 insertions(+), 55 deletions(-) create mode 100644 airflow/api_fastapi/core_api/security.py create mode 100644 tests/api_fastapi/core_api/test_security.py diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py index 9ddd97b6cbe04..4bf6ae9f6b77c 100644 --- a/airflow/api_fastapi/app.py +++ b/airflow/api_fastapi/app.py @@ -24,10 +24,14 @@ from airflow.api_fastapi.core_api.app import init_config, init_dag_bag, init_plugins, init_views from airflow.api_fastapi.execution_api.app import create_task_execution_api_app +from airflow.auth.managers.base_auth_manager import BaseAuthManager +from airflow.configuration import conf +from airflow.exceptions import AirflowConfigException log = logging.getLogger(__name__) app: FastAPI | None = None +auth_manager: BaseAuthManager | None = None @asynccontextmanager @@ -57,6 +61,7 @@ def create_app(apps: str = "all") -> FastAPI: init_dag_bag(app) init_views(app) init_plugins(app) + init_auth_manager() if "execution" in apps_list or "all" in apps_list: task_exec_api_app = create_task_execution_api_app(app) @@ -79,3 +84,43 @@ def purge_cached_app() -> None: """Remove the cached version of the app in global state.""" global app app = None + + +def get_auth_manager_cls() -> type[BaseAuthManager]: + """ + Return just the auth manager class without initializing it. + + Useful to save execution time if only static methods need to be called. + """ + auth_manager_cls = conf.getimport(section="core", key="auth_manager") + + if not auth_manager_cls: + raise AirflowConfigException( + "No auth manager defined in the config. " + "Please specify one using section/key [core/auth_manager]." + ) + + return auth_manager_cls + + +def init_auth_manager() -> BaseAuthManager: + """ + Initialize the auth manager. + + Import the user manager class and instantiate it. + """ + global auth_manager + auth_manager_cls = get_auth_manager_cls() + auth_manager = auth_manager_cls() + auth_manager.init() + return auth_manager + + +def get_auth_manager() -> BaseAuthManager: + """Return the auth manager, provided it's been initialized before.""" + if auth_manager is None: + raise RuntimeError( + "Auth Manager has not been initialized yet. " + "The `init_auth_manager` method needs to be called first." + ) + return auth_manager diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py new file mode 100644 index 0000000000000..ede628e04aa70 --- /dev/null +++ b/airflow/api_fastapi/core_api/security.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from functools import cache +from typing import Any, Callable + +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from jwt import InvalidTokenError +from typing_extensions import Annotated + +from airflow.api_fastapi.app import get_auth_manager +from airflow.auth.managers.base_auth_manager import ResourceMethod +from airflow.auth.managers.models.base_user import BaseUser +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails +from airflow.configuration import conf +from airflow.utils.jwt_signer import JWTSigner + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +@cache +def get_signer() -> JWTSigner: + return JWTSigner( + secret_key=conf.get("api", "auth_jwt_secret"), + expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"), + audience="front-apis", + ) + + +def get_user(token_str: Annotated[str, Depends(oauth2_scheme)]) -> BaseUser: + try: + signer = get_signer() + payload: dict[str, Any] = signer.verify_token(token_str) + return get_auth_manager().deserialize_user(payload) + except InvalidTokenError: + raise HTTPException(403, "Forbidden") + + +def requires_access_dag(method: ResourceMethod, access_entity: DagAccessEntity | None = None) -> Callable: + def inner( + dag_id: str | None = None, + user: Annotated[BaseUser | None, Depends(get_user)] = None, + ) -> None: + def callback(): + return get_auth_manager().is_authorized_dag( + method=method, access_entity=access_entity, details=DagDetails(id=dag_id), user=user + ) + + _requires_access( + is_authorized_callback=callback, + ) + + return inner + + +def _requires_access( + *, + is_authorized_callback: Callable[[], bool], +) -> None: + if not is_authorized_callback(): + raise HTTPException(403, "Forbidden") diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 69b5969c827c6..028d4dadb1326 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -19,11 +19,12 @@ from abc import abstractmethod from functools import cached_property -from typing import TYPE_CHECKING, Container, Literal, Sequence +from typing import TYPE_CHECKING, Any, Container, Generic, Literal, Sequence, TypeVar from flask_appbuilder.menu import MenuItem from sqlalchemy import select +from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( DagDetails, ) @@ -37,7 +38,6 @@ from flask import Blueprint from sqlalchemy.orm import Session - from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.batch_apis import ( IsAuthorizedConnectionRequest, IsAuthorizedDagRequest, @@ -59,8 +59,10 @@ ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"] +T = TypeVar("T", bound=BaseUser) -class BaseAuthManager(LoggingMixin): + +class BaseAuthManager(Generic[T], LoggingMixin): """ Class to derive in order to implement concrete auth managers. @@ -69,7 +71,7 @@ class BaseAuthManager(LoggingMixin): :param appbuilder: the flask app builder """ - def __init__(self, appbuilder: AirflowAppBuilder) -> None: + def __init__(self, appbuilder: AirflowAppBuilder | None = None) -> None: super().__init__() self.appbuilder = appbuilder @@ -93,9 +95,17 @@ def get_user_display_name(self) -> str: return self.get_user_name() @abstractmethod - def get_user(self) -> BaseUser | None: + def get_user(self) -> T | None: """Return the user associated to the user in session.""" + @abstractmethod + def deserialize_user(self, token: dict[str, Any]) -> T: + """Create a user object from dict.""" + + @abstractmethod + def serialize_user(self, user: T) -> dict[str, Any]: + """Create a dict from a user object.""" + def get_user_id(self) -> str | None: """Return the user ID associated to the user in session.""" user = self.get_user() @@ -132,7 +142,7 @@ def is_authorized_configuration( *, method: ResourceMethod, details: ConfigurationDetails | None = None, - user: BaseUser | None = None, + user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on configuration. @@ -148,7 +158,7 @@ def is_authorized_connection( *, method: ResourceMethod, details: ConnectionDetails | None = None, - user: BaseUser | None = None, + user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a connection. @@ -165,7 +175,7 @@ def is_authorized_dag( method: ResourceMethod, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: BaseUser | None = None, + user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. @@ -183,7 +193,7 @@ def is_authorized_asset( *, method: ResourceMethod, details: AssetDetails | None = None, - user: BaseUser | None = None, + user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on an asset. @@ -199,7 +209,7 @@ def is_authorized_pool( *, method: ResourceMethod, details: PoolDetails | None = None, - user: BaseUser | None = None, + user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a pool. @@ -215,7 +225,7 @@ def is_authorized_variable( *, method: ResourceMethod, details: VariableDetails | None = None, - user: BaseUser | None = None, + user: T | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a variable. @@ -230,7 +240,7 @@ def is_authorized_view( self, *, access_view: AccessView, - user: BaseUser | None = None, + user: T | None = None, ) -> bool: """ Return whether the user is authorized to access a read-only state of the installation. @@ -241,7 +251,7 @@ def is_authorized_view( @abstractmethod def is_authorized_custom_view( - self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: T | None = None ): """ Return whether the user is authorized to perform a given action on a custom view. diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index 78dccf7c2a980..48baa02e7c75f 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -22,7 +22,7 @@ import random from collections import namedtuple from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from flask import session, url_for from termcolor import colored @@ -33,7 +33,6 @@ from airflow.configuration import AIRFLOW_HOME if TYPE_CHECKING: - from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( AccessView, AssetDetails, @@ -68,7 +67,7 @@ class SimpleAuthManagerRole(namedtuple("SimpleAuthManagerRole", "name order"), E ADMIN = "ADMIN", 3 -class SimpleAuthManager(BaseAuthManager): +class SimpleAuthManager(BaseAuthManager[SimpleAuthManagerUser]): """ Simple auth manager. @@ -89,6 +88,8 @@ def get_generated_password_file() -> str: ) def init(self) -> None: + if not self.appbuilder: + return user_passwords_from_file = {} # Read passwords from file @@ -115,8 +116,9 @@ def init(self) -> None: file.write(json.dumps(self.passwords)) def is_logged_in(self) -> bool: - return "user" in session or self.appbuilder.get_app.config.get( - "SIMPLE_AUTH_MANAGER_ALL_ADMINS", False + return "user" in session or ( + self.appbuilder is not None + and self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False) ) def get_url_login(self, **kwargs) -> str: @@ -128,28 +130,34 @@ def get_url_logout(self) -> str: def get_user(self) -> SimpleAuthManagerUser | None: if not self.is_logged_in(): return None - if self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False): + if self.appbuilder and self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_ALL_ADMINS", False): return SimpleAuthManagerUser(username="anonymous", role="admin") else: return session["user"] + def deserialize_user(self, token: dict[str, Any]) -> SimpleAuthManagerUser: + return SimpleAuthManagerUser(username=token["username"], role=token["role"]) + + def serialize_user(self, user: SimpleAuthManagerUser) -> dict[str, Any]: + return {"username": user.username, "role": user.role} + def is_authorized_configuration( self, *, method: ResourceMethod, details: ConfigurationDetails | None = None, - user: BaseUser | None = None, + user: SimpleAuthManagerUser | None = None, ) -> bool: - return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP) + return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP, user=user) def is_authorized_connection( self, *, method: ResourceMethod, details: ConnectionDetails | None = None, - user: BaseUser | None = None, + user: SimpleAuthManagerUser | None = None, ) -> bool: - return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP) + return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP, user=user) def is_authorized_dag( self, @@ -157,46 +165,65 @@ def is_authorized_dag( method: ResourceMethod, access_entity: DagAccessEntity | None = None, details: DagDetails | None = None, - user: BaseUser | None = None, + user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized( method=method, allow_get_role=SimpleAuthManagerRole.VIEWER, allow_role=SimpleAuthManagerRole.USER, + user=user, ) def is_authorized_asset( - self, *, method: ResourceMethod, details: AssetDetails | None = None, user: BaseUser | None = None + self, + *, + method: ResourceMethod, + details: AssetDetails | None = None, + user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized( method=method, allow_get_role=SimpleAuthManagerRole.VIEWER, allow_role=SimpleAuthManagerRole.OP, + user=user, ) def is_authorized_pool( - self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + self, + *, + method: ResourceMethod, + details: PoolDetails | None = None, + user: SimpleAuthManagerUser | None = None, ) -> bool: return self._is_authorized( method=method, allow_get_role=SimpleAuthManagerRole.VIEWER, allow_role=SimpleAuthManagerRole.OP, + user=user, ) def is_authorized_variable( - self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + self, + *, + method: ResourceMethod, + details: VariableDetails | None = None, + user: SimpleAuthManagerUser | None = None, ) -> bool: - return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP) + return self._is_authorized(method=method, allow_role=SimpleAuthManagerRole.OP, user=user) - def is_authorized_view(self, *, access_view: AccessView, user: BaseUser | None = None) -> bool: - return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER) + def is_authorized_view( + self, *, access_view: AccessView, user: SimpleAuthManagerUser | None = None + ) -> bool: + return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER, user=user) def is_authorized_custom_view( - self, *, method: ResourceMethod | str, resource_name: str, user: BaseUser | None = None + self, *, method: ResourceMethod | str, resource_name: str, user: SimpleAuthManagerUser | None = None ): - return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER) + return self._is_authorized(method="GET", allow_role=SimpleAuthManagerRole.VIEWER, user=user) def register_views(self) -> None: + if not self.appbuilder: + return self.appbuilder.add_view_no_menu( SimpleAuthManagerAuthenticationViews( users=self.appbuilder.get_app.config.get("SIMPLE_AUTH_MANAGER_USERS", []), @@ -210,6 +237,7 @@ def _is_authorized( method: ResourceMethod, allow_role: SimpleAuthManagerRole, allow_get_role: SimpleAuthManagerRole | None = None, + user: SimpleAuthManagerUser | None = None, ): """ Return whether the user is authorized to access a given resource. @@ -219,8 +247,9 @@ def _is_authorized( equal than this role, they have access :param allow_get_role: minimal role giving access to the resource, if the user's role is greater or equal than this role, they have access. If not provided, ``allow_role`` is used + :param user: the user to check the authorization for. If not provided, the current user is used """ - user = self.get_user() + user = user or self.get_user() if not user: return False diff --git a/airflow/auth/managers/simple/views/auth.py b/airflow/auth/managers/simple/views/auth.py index 8ab02d0a01567..6e4cf0c399416 100644 --- a/airflow/auth/managers/simple/views/auth.py +++ b/airflow/auth/managers/simple/views/auth.py @@ -23,8 +23,10 @@ from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.configuration import conf +from airflow.utils.jwt_signer import JWTSigner from airflow.utils.state import State from airflow.www.app import csrf +from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.views import AirflowBaseView logger = logging.getLogger(__name__) @@ -80,9 +82,18 @@ def login_submit(self): if not username or not password or len(found_users) == 0: return redirect(url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"])) - session["user"] = SimpleAuthManagerUser( + user = SimpleAuthManagerUser( username=username, role=found_users[0]["role"], ) + # Will be removed once Airflow uses the new UI + session["user"] = user - return redirect(url_for("Airflow.index")) + signer = JWTSigner( + secret_key=conf.get("api", "auth_jwt_secret"), + expiration_time_in_seconds=conf.getint("api", "auth_jwt_expiration_time"), + audience="front-apis", + ) + token = signer.generate_signed_token(get_auth_manager().serialize_user(user)) + + return redirect(url_for("Airflow.index", token=token)) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 04ac0d802a999..4a7fe5f189719 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1404,6 +1404,26 @@ api: version_added: 2.7.0 example: ~ default: "False" + auth_jwt_secret: + description: | + Secret key used to encode and decode JWT tokens to authenticate to public and private APIs. + It should be as random as possible. However, when running more than 1 instances of API services, + make sure all of them use the same ``jwt_secret`` otherwise calls will fail on authentication. + version_added: 3.0.0 + type: string + sensitive: true + example: ~ + default: "{JWT_SECRET_KEY}" + auth_jwt_expiration_time: + description: | + Number in seconds until the JWT token used for authentication expires. When the token expires, + all API calls using this token will fail on authentication. + Make sure that time on ALL the machines that you run airflow components on is synchronized + (for example using ntpd) otherwise you might get "forbidden" errors. + version_added: 3.0.0 + type: integer + example: ~ + default: "86400" lineage: description: ~ options: diff --git a/airflow/configuration.py b/airflow/configuration.py index bc808d6bfc28e..521af6cbe320d 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -2143,6 +2143,7 @@ def initialize_auth_manager() -> BaseAuthManager: TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, "plugins") SECRET_KEY = b64encode(os.urandom(16)).decode("utf-8") +JWT_SECRET_KEY = b64encode(os.urandom(16)).decode("utf-8") FERNET_KEY = "" # Set only if needed when generating a new file WEBSERVER_CONFIG = "" # Set by initialize_config diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 32787428cc5ce..b2df194de0dfd 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -898,6 +898,7 @@ Jupyter jupyter jupytercmd JWT +jwt Kafka kafka Kalibrr diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index 414907961ce71..8271aad330229 100644 --- a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -430,7 +430,8 @@ def get_cli_commands() -> list[CLICommand]: ] def register_views(self) -> None: - self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews()) + if self.appbuilder: + self.appbuilder.add_view_no_menu(AwsAuthManagerAuthenticationViews()) @staticmethod def _get_menu_item_request(resource_name: str) -> IsAuthorizedRequest: diff --git a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index e93e440f5ddfe..adebe30dd534a 100644 --- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -21,7 +21,7 @@ import warnings from functools import cached_property from pathlib import Path -from typing import TYPE_CHECKING, Container +from typing import TYPE_CHECKING, Any, Container import packaging.version from connexion import FlaskApi @@ -82,7 +82,7 @@ RESOURCE_WEBSITE, RESOURCE_XCOM, ) -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.yaml import safe_load from airflow.version import version from airflow.www.constants import SWAGGER_BUNDLE, SWAGGER_ENABLED @@ -131,13 +131,18 @@ } -class FabAuthManager(BaseAuthManager): +class FabAuthManager(BaseAuthManager[User]): """ Flask-AppBuilder auth manager. This auth manager is responsible for providing a backward compatible user management experience to users. """ + def init(self) -> None: + """Run operations when Airflow is initializing.""" + if self.appbuilder: + self._sync_appbuilder_roles() + @staticmethod def get_cli_commands() -> list[CLICommand]: """Vends CLI commands to be included in Airflow CLI.""" @@ -199,9 +204,12 @@ def get_user(self) -> User: return current_user - def init(self) -> None: - """Run operations when Airflow is initializing.""" - self._sync_appbuilder_roles() + def deserialize_user(self, token: dict[str, Any]) -> User: + with create_session() as session: + return session.get(User, token["id"]) + + def serialize_user(self, user: User) -> dict[str, Any]: + return {"id": user.id} def is_logged_in(self) -> bool: """Return whether the user is logged in.""" @@ -209,8 +217,10 @@ def is_logged_in(self) -> bool: if Version(Version(version).base_version) < Version("3.0.0"): return not user.is_anonymous and user.is_active else: - return self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None) or ( - not user.is_anonymous and user.is_active + return ( + self.appbuilder + and self.appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None) + or (not user.is_anonymous and user.is_active) ) def is_authorized_configuration( @@ -376,6 +386,9 @@ def security_manager(self) -> FabAirflowSecurityManagerOverride: FabAirflowSecurityManagerOverride, ) + if not self.appbuilder: + raise AirflowException("AppBuilder is not initialized.") + sm_from_config = self.appbuilder.get_app.config.get("SECURITY_MANAGER_CLASS") if sm_from_config: if not issubclass(sm_from_config, FabAirflowSecurityManagerOverride): @@ -547,6 +560,9 @@ def _get_root_dag_id(self, dag_id: str) -> str: :meta private: """ + if not self.appbuilder: + raise AirflowException("AppBuilder is not initialized.") + if "." in dag_id and hasattr(DagModel, "root_dag_id"): return self.appbuilder.get_session.scalar( select(DagModel.dag_id, DagModel.root_dag_id).where(DagModel.dag_id == dag_id).limit(1) diff --git a/providers/tests/fab/auth_manager/test_fab_auth_manager.py b/providers/tests/fab/auth_manager/test_fab_auth_manager.py index d298f7667eaaf..1994e910f9ebd 100644 --- a/providers/tests/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/tests/fab/auth_manager/test_fab_auth_manager.py @@ -27,6 +27,8 @@ from airflow.exceptions import AirflowConfigException, AirflowException +from providers.tests.fab.auth_manager.api_endpoints.api_connexion_utils import create_user + try: from airflow.auth.managers.models.resource_details import AccessView, DagAccessEntity, DagDetails except ImportError: @@ -141,6 +143,16 @@ def test_get_user_from_flask_g(self, mock_current_user, minimal_app_for_auth_api with user_set(minimal_app_for_auth_api, flask_g_user): assert auth_manager.get_user() == flask_g_user + def test_deserialize_user(self, flask_app, auth_manager_with_appbuilder): + user = create_user(flask_app, "test") + result = auth_manager_with_appbuilder.deserialize_user({"id": user.id}) + assert user == result + + def test_serialize_user(self, flask_app, auth_manager_with_appbuilder): + user = create_user(flask_app, "test") + result = auth_manager_with_appbuilder.serialize_user(user) + assert result == {"id": user.id} + @pytest.mark.db_test @mock.patch.object(FabAuthManager, "get_user") def test_is_logged_in(self, mock_get_user, auth_manager_with_appbuilder): @@ -338,11 +350,17 @@ def test_is_authorized(self, api_name, method, user_permissions, expected_result ], ) def test_is_authorized_dag( - self, method, dag_access_entity, dag_details, user_permissions, expected_result, auth_manager + self, + method, + dag_access_entity, + dag_details, + user_permissions, + expected_result, + auth_manager_with_appbuilder, ): user = Mock() user.perms = user_permissions - result = auth_manager.is_authorized_dag( + result = auth_manager_with_appbuilder.is_authorized_dag( method=method, access_entity=dag_access_entity, details=dag_details, user=user ) assert result == expected_result diff --git a/tests/api_fastapi/core_api/test_security.py b/tests/api_fastapi/core_api/test_security.py new file mode 100644 index 0000000000000..90bf3f647bba3 --- /dev/null +++ b/tests/api_fastapi/core_api/test_security.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest +from fastapi import HTTPException +from jwt import InvalidTokenError + +from airflow.api_fastapi.app import create_app +from airflow.api_fastapi.core_api.security import get_user, requires_access_dag +from airflow.auth.managers.models.resource_details import DagAccessEntity +from airflow.auth.managers.simple.user import SimpleAuthManagerUser + +from tests_common.test_utils.config import conf_vars + + +class TestFastApiSecurity: + @classmethod + def setup_class(cls): + with conf_vars( + { + ( + "core", + "auth_manager", + ): "airflow.auth.managers.simple.simple_auth_manager.SimpleAuthManager", + } + ): + create_app() + + @patch("airflow.api_fastapi.core_api.security.get_signer") + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + def test_get_user(self, mock_get_auth_manager, mock_get_signer): + token_str = "test-token" + user_dict = {"user": "XXXXXXXXX"} + user = SimpleAuthManagerUser(username="username", role="admin") + + auth_manager = Mock() + auth_manager.deserialize_user.return_value = user + mock_get_auth_manager.return_value = auth_manager + + signer = Mock() + signer.verify_token.return_value = user_dict + mock_get_signer.return_value = signer + + result = get_user(token_str) + + signer.verify_token.assert_called_once_with(token_str) + auth_manager.deserialize_user.assert_called_once_with(user_dict) + assert result == user + + @patch("airflow.api_fastapi.core_api.security.get_signer") + def test_get_user_unsuccessful(self, mock_get_signer): + token_str = "test-token" + + signer = Mock() + signer.verify_token.side_effect = InvalidTokenError() + mock_get_signer.return_value = signer + + with pytest.raises(HTTPException, match="Forbidden"): + get_user(token_str) + + signer.verify_token.assert_called_once_with(token_str) + + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + def test_requires_access_dag_authorized(self, mock_get_auth_manager): + auth_manager = Mock() + auth_manager.is_authorized_dag.return_value = True + mock_get_auth_manager.return_value = auth_manager + + requires_access_dag("GET", DagAccessEntity.CODE)("dag-id", Mock()) + + auth_manager.is_authorized_dag.assert_called_once() + + @patch("airflow.api_fastapi.core_api.security.get_auth_manager") + def test_requires_access_dag_unauthorized(self, mock_get_auth_manager): + auth_manager = Mock() + auth_manager.is_authorized_dag.return_value = False + mock_get_auth_manager.return_value = auth_manager + + with pytest.raises(HTTPException, match="Forbidden"): + requires_access_dag("GET", DagAccessEntity.CODE)("dag-id", Mock()) + + auth_manager.is_authorized_dag.assert_called_once() diff --git a/tests/auth/managers/simple/test_simple_auth_manager.py b/tests/auth/managers/simple/test_simple_auth_manager.py index 434c0d60fcc76..07289f6f0020c 100644 --- a/tests/auth/managers/simple/test_simple_auth_manager.py +++ b/tests/auth/managers/simple/test_simple_auth_manager.py @@ -132,6 +132,16 @@ def test_get_user_return_none_when_not_logged_in(self, mock_is_logged_in, auth_m assert result is None + def test_deserialize_user(self, auth_manager): + result = auth_manager.deserialize_user({"username": "test", "role": "admin"}) + assert result.username == "test" + assert result.role == "admin" + + def test_serialize_user(self, auth_manager): + user = SimpleAuthManagerUser(username="test", role="admin") + result = auth_manager.serialize_user(user) + assert result == {"username": "test", "role": "admin"} + @pytest.mark.db_test @patch.object(SimpleAuthManager, "is_logged_in") @pytest.mark.parametrize( @@ -280,10 +290,7 @@ def test_is_authorized_methods_viewer_role_required_for_get( assert getattr(auth_manager_with_appbuilder, api)(method=method) is result @pytest.mark.db_test - @patch( - "airflow.providers.amazon.aws.auth_manager.views.auth.conf.get_mandatory_value", return_value="test" - ) - def test_register_views(self, _, auth_manager_with_appbuilder): + def test_register_views(self, auth_manager_with_appbuilder): auth_manager_with_appbuilder.appbuilder.add_view_no_menu = Mock() auth_manager_with_appbuilder.register_views() auth_manager_with_appbuilder.appbuilder.add_view_no_menu.assert_called_once() diff --git a/tests/auth/managers/simple/views/test_auth.py b/tests/auth/managers/simple/views/test_auth.py index cd4628c2d02b6..e3e7b29e2cc1b 100644 --- a/tests/auth/managers/simple/views/test_auth.py +++ b/tests/auth/managers/simple/views/test_auth.py @@ -17,6 +17,7 @@ from __future__ import annotations import json +from unittest.mock import Mock, patch import pytest from flask import session, url_for @@ -67,11 +68,15 @@ def test_logout_redirects_to_login_and_clear_user(self, simple_app): "username, password, is_successful", [("test", "test", True), ("test", "test2", False), ("", "", False)], ) - def test_login_submit(self, simple_app, username, password, is_successful): + @patch("airflow.auth.managers.simple.views.auth.JWTSigner") + def test_login_submit(self, mock_jwt_signer, simple_app, username, password, is_successful): + signer = Mock() + signer.generate_signed_token.return_value = "token" + mock_jwt_signer.return_value = signer with simple_app.test_client() as client: response = client.post("/login_submit", data={"username": username, "password": password}) assert response.status_code == 302 if is_successful: - assert response.location == url_for("Airflow.index") + assert response.location == url_for("Airflow.index", token="token") else: assert response.location == url_for("SimpleAuthManagerAuthenticationViews.login", error=["1"]) diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index 82efe20048b71..c62076a4654d0 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -16,13 +16,14 @@ # under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from unittest.mock import MagicMock, Mock, patch import pytest from flask_appbuilder.menu import Menu from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( ConnectionDetails, DagDetails, @@ -32,7 +33,6 @@ from airflow.exceptions import AirflowException if TYPE_CHECKING: - from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( AccessView, AssetDetails, @@ -41,10 +41,16 @@ ) -class EmptyAuthManager(BaseAuthManager): +class EmptyAuthManager(BaseAuthManager[BaseUser]): def get_user(self) -> BaseUser: raise NotImplementedError() + def deserialize_user(self, token: dict[str, Any]) -> BaseUser: + raise NotImplementedError() + + def serialize_user(self, user: BaseUser) -> dict[str, Any]: + raise NotImplementedError() + def is_authorized_configuration( self, *, diff --git a/tests/core/test_configuration.py b/tests/core/test_configuration.py index 2167d71ed5c32..1fb3b27057360 100644 --- a/tests/core/test_configuration.py +++ b/tests/core/test_configuration.py @@ -1584,6 +1584,7 @@ def test_sensitive_values(): ("database", "sql_alchemy_conn"), ("core", "fernet_key"), ("core", "internal_api_secret_key"), + ("api", "auth_jwt_secret"), ("webserver", "secret_key"), ("secrets", "backend_kwargs"), ("sentry", "sentry_dsn"),