diff --git a/airflow/www/extensions/init_appbuilder.py b/airflow/www/extensions/init_appbuilder.py index a3b48cc8d9d61..2238066e6e953 100644 --- a/airflow/www/extensions/init_appbuilder.py +++ b/airflow/www/extensions/init_appbuilder.py @@ -212,6 +212,10 @@ def init_app(self, app, session): auth_manager = create_auth_manager() auth_manager.appbuilder = self auth_manager.init() + # FAB auth manager is used in both the old FAB UI and the new React UI backend by Fastapi. + # It can behave differently depending on the application. Setting this flag so that it knows this + # instance is run in FAB. + auth_manager.is_in_fab = True if hasattr(auth_manager, "security_manager"): self.sm = auth_manager.security_manager else: diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index c3d086b56e168..698f2a9219f63 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -146,11 +146,21 @@ class FabAuthManager(BaseAuthManager[User]): appbuilder: AirflowAppBuilder | None = None + is_in_fab: bool = False + """ + Whether the instance is run in FAB or Fastapi. + Can be deleted once the Airflow 2 legacy UI is removed. + """ + def init(self) -> None: """Run operations when Airflow is initializing.""" if self.appbuilder: self._sync_appbuilder_roles() + @cached_property + def fastapi_endpoint(self) -> str: + return conf.get("fastapi", "base_url") + @staticmethod def get_cli_commands() -> list[CLICommand]: """Vends CLI commands to be included in Airflow CLI.""" @@ -443,12 +453,15 @@ def security_manager(self) -> FabAirflowSecurityManagerOverride: def get_url_login(self, **kwargs) -> str: """Return the login page url.""" - if not self.security_manager.auth_view: - raise AirflowException("`auth_view` not defined in the security manager.") - if next_url := kwargs.get("next_url"): - return url_for(f"{self.security_manager.auth_view.endpoint}.login", next=next_url) + if self.is_in_fab: + if not self.security_manager.auth_view: + raise AirflowException("`auth_view` not defined in the security manager.") + if next_url := kwargs.get("next_url"): + return url_for(f"{self.security_manager.auth_view.endpoint}.login", next=next_url) + else: + return url_for(f"{self.security_manager.auth_view.endpoint}.login") else: - return url_for(f"{self.security_manager.auth_view.endpoint}.login") + return f"{self.fastapi_endpoint}/auth/login" def get_url_logout(self): """Return the logout page url.""" diff --git a/providers/fab/tests/provider_tests/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/provider_tests/fab/auth_manager/test_fab_auth_manager.py index 1fef2598ba0e4..a6291f20b572d 100644 --- a/providers/fab/tests/provider_tests/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/fab/tests/provider_tests/fab/auth_manager/test_fab_auth_manager.py @@ -574,26 +574,9 @@ class TestSecurityManager: ): auth_manager_with_appbuilder.security_manager - @pytest.mark.db_test - def test_get_url_login_when_auth_view_not_defined(self, auth_manager_with_appbuilder): - with pytest.raises(AirflowException, match="`auth_view` not defined in the security manager."): - auth_manager_with_appbuilder.get_url_login() - - @pytest.mark.db_test - @mock.patch("airflow.providers.fab.auth_manager.fab_auth_manager.url_for") - def test_get_url_login(self, mock_url_for, auth_manager_with_appbuilder): - auth_manager_with_appbuilder.security_manager.auth_view = Mock() - auth_manager_with_appbuilder.security_manager.auth_view.endpoint = "test_endpoint" - auth_manager_with_appbuilder.get_url_login() - mock_url_for.assert_called_once_with("test_endpoint.login") - - @pytest.mark.db_test - @mock.patch("airflow.providers.fab.auth_manager.fab_auth_manager.url_for") - def test_get_url_login_with_next(self, mock_url_for, auth_manager_with_appbuilder): - auth_manager_with_appbuilder.security_manager.auth_view = Mock() - auth_manager_with_appbuilder.security_manager.auth_view.endpoint = "test_endpoint" - auth_manager_with_appbuilder.get_url_login(next_url="next_url") - mock_url_for.assert_called_once_with("test_endpoint.login", next="next_url") + def test_get_url_login(self, auth_manager): + result = auth_manager.get_url_login() + assert result == "http://localhost:29091/auth/login" @pytest.mark.db_test def test_get_url_logout_when_auth_view_not_defined(self, auth_manager_with_appbuilder):