diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 698f2a9219f63..e08a54ecf5105 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -185,20 +185,14 @@ def get_cli_commands() -> list[CLICommand]: return commands def get_fastapi_app(self) -> FastAPI | None: - flask_blueprint = self.get_api_endpoints() - - if not flask_blueprint: - return None - flask_app = create_app(enable_plugins=False) - flask_app.register_blueprint(flask_blueprint) app = FastAPI( title="FAB auth manager API", description=( "This is FAB auth manager API. This API is only available if the auth manager used in " "the Airflow environment is FAB auth manager. " - "This API provides endpoints to manager users and permissions managed by the FAB auth " + "This API provides endpoints to manage users and permissions managed by the FAB auth " "manager." ), ) diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py index 6438fea6282de..23b23fb51259e 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py @@ -110,9 +110,9 @@ from airflow.providers.fab.www.security import permissions from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2 from airflow.providers.fab.www.session import ( + AirflowDatabaseSessionInterface, AirflowDatabaseSessionInterface as FabAirflowDatabaseSessionInterface, ) -from airflow.www.session import AirflowDatabaseSessionInterface if TYPE_CHECKING: from airflow.providers.fab.www.security.permissions import RESOURCE_ASSET diff --git a/providers/fab/src/airflow/providers/fab/www/api_connexion/exceptions.py b/providers/fab/src/airflow/providers/fab/www/api_connexion/exceptions.py index feaf38e8fd09b..ef2e2ab9b4bbc 100644 --- a/providers/fab/src/airflow/providers/fab/www/api_connexion/exceptions.py +++ b/providers/fab/src/airflow/providers/fab/www/api_connexion/exceptions.py @@ -17,12 +17,16 @@ from __future__ import annotations from http import HTTPStatus -from typing import Any +from typing import TYPE_CHECKING, Any -from connexion import ProblemException +import werkzeug +from connexion import FlaskApi, ProblemException, problem from airflow.utils.docs import get_docs_url +if TYPE_CHECKING: + import flask + doc_link = get_docs_url("stable-rest-api-ref.html") EXCEPTIONS_LINK_MAP = { @@ -36,6 +40,39 @@ } +def common_error_handler(exception: BaseException) -> flask.Response: + """Use to capture connexion exceptions and add link to the type field.""" + if isinstance(exception, ProblemException): + link = EXCEPTIONS_LINK_MAP.get(exception.status) + if link: + response = problem( + status=exception.status, + title=exception.title, + detail=exception.detail, + type=link, + instance=exception.instance, + headers=exception.headers, + ext=exception.ext, + ) + else: + response = problem( + status=exception.status, + title=exception.title, + detail=exception.detail, + type=exception.type, + instance=exception.instance, + headers=exception.headers, + ext=exception.ext, + ) + else: + if not isinstance(exception, werkzeug.exceptions.HTTPException): + exception = werkzeug.exceptions.InternalServerError() + + response = problem(title=exception.name, detail=exception.description, status=exception.code) + + return FlaskApi.get_response(response) + + class NotFound(ProblemException): """Raise when the object cannot be found.""" diff --git a/providers/fab/src/airflow/providers/fab/www/app.py b/providers/fab/src/airflow/providers/fab/www/app.py index d67ec81d0337d..036106e545e22 100644 --- a/providers/fab/src/airflow/providers/fab/www/app.py +++ b/providers/fab/src/airflow/providers/fab/www/app.py @@ -32,7 +32,13 @@ from airflow.providers.fab.www.extensions.init_jinja_globals import init_jinja_globals from airflow.providers.fab.www.extensions.init_manifest_files import configure_manifest_files from airflow.providers.fab.www.extensions.init_security import init_api_auth, init_xframe_protection -from airflow.providers.fab.www.extensions.init_views import init_error_handlers, init_plugins +from airflow.providers.fab.www.extensions.init_session import init_airflow_session_interface +from airflow.providers.fab.www.extensions.init_views import ( + init_api_auth_provider, + init_api_error_handlers, + init_error_handlers, + init_plugins, +) app: Flask | None = None @@ -58,6 +64,8 @@ def create_app(enable_plugins: bool): if "SQLALCHEMY_ENGINE_OPTIONS" not in flask_app.config: flask_app.config["SQLALCHEMY_ENGINE_OPTIONS"] = settings.prepare_engine_args() + csrf.init_app(flask_app) + db = SQLA() db.session = settings.Session db.init_app(flask_app) @@ -68,11 +76,15 @@ def create_app(enable_plugins: bool): with flask_app.app_context(): init_appbuilder(flask_app, enable_plugins=enable_plugins) + init_error_handlers(flask_app) if enable_plugins: init_plugins(flask_app) - init_error_handlers(flask_app) + else: + init_api_auth_provider(flask_app) + init_api_error_handlers(flask_app) init_jinja_globals(flask_app, enable_plugins=enable_plugins) init_xframe_protection(flask_app) + init_airflow_session_interface(flask_app) return flask_app diff --git a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py index 8a711a0164cbf..912e1533ed6fe 100644 --- a/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py +++ b/providers/fab/src/airflow/providers/fab/www/extensions/init_appbuilder.py @@ -210,6 +210,15 @@ def init_app(self, app, session): self._add_admin_views() self._add_addon_views() self._init_extension(app) + self._swap_url_filter() + + def _swap_url_filter(self): + """Use our url filtering util function so there is consistency between FAB and Airflow routes.""" + from flask_appbuilder.security import views as fab_sec_views + + from airflow.providers.fab.www.views import get_safe_url + + fab_sec_views.get_safe_redirect = get_safe_url def _init_extension(self, app): app.appbuilder = self @@ -518,6 +527,9 @@ def add_view_no_menu(self, baseview, endpoint=None, static_folder=None): def get_url_for_index(self): return url_for(f"{self.indexview.endpoint}.{self.indexview.default_view}") + def get_url_for_login_with(self, next_url: str | None = None) -> str: + return get_auth_manager().get_url_login(next_url=next_url) + def get_url_for_locale(self, lang): return url_for( f"{self.bm.locale_view.endpoint}.{self.bm.locale_view.default_view}", diff --git a/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py b/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py index 588276735fc88..1fbe234bc6dbd 100644 --- a/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py +++ b/providers/fab/src/airflow/providers/fab/www/extensions/init_views.py @@ -22,16 +22,11 @@ from connexion import Resolver from connexion.decorators.validation import RequestBodyValidator -from connexion.exceptions import BadRequestProblem -from flask import jsonify -from starlette import status +from connexion.exceptions import BadRequestProblem, ProblemException +from flask import request -from airflow.providers.fab.www.api_connexion.exceptions import ( - BadRequest, - NotFound, - PermissionDenied, - Unauthenticated, -) +from airflow.api_connexion.exceptions import common_error_handler +from airflow.api_fastapi.app import get_auth_manager if TYPE_CHECKING: from flask import Flask @@ -39,6 +34,21 @@ log = logging.getLogger(__name__) +def init_appbuilder_views(app): + """Initialize Web UI views.""" + from airflow.models import import_all_models + from airflow.providers.fab.www import views + + import_all_models() + + appbuilder = app.appbuilder + + # Remove the session from scoped_session registry to avoid + # reusing a session with a disconnected connection + appbuilder.session.remove() + appbuilder.add_view_no_menu(views.FabIndexView()) + + class _LazyResolution: """ OpenAPI endpoint that lazily resolves the function on first use. @@ -121,28 +131,47 @@ def init_plugins(app): app.register_blueprint(blue_print["blueprint"]) -def init_error_handlers(app: Flask): - """Add custom errors handlers.""" +base_paths: list[str] = [] # contains the list of base paths that have api endpoints - def handle_bad_request(error): - response = {"error": "Bad request"} - return jsonify(response), status.HTTP_400_BAD_REQUEST - def handle_not_found(error): - response = {"error": "Not found"} - return jsonify(response), status.HTTP_404_NOT_FOUND +def init_api_error_handlers(app: Flask) -> None: + """Add error handlers for 404 and 405 errors for existing API paths.""" + from airflow.providers.fab.www import views - def handle_unauthenticated(error): - response = {"error": "User is not authenticated"} - return jsonify(response), status.HTTP_401_UNAUTHORIZED + @app.errorhandler(404) + def _handle_api_not_found(ex): + if any([request.path.startswith(p) for p in base_paths]): + # 404 errors are never handled on the blueprint level + # unless raised from a view func so actual 404 errors, + # i.e. "no route for it" defined, need to be handled + # here on the application level + return common_error_handler(ex) + else: + return views.not_found(ex) + + @app.errorhandler(405) + def _handle_method_not_allowed(ex): + if any([request.path.startswith(p) for p in base_paths]): + return common_error_handler(ex) + else: + return views.method_not_allowed(ex) + + app.register_error_handler(ProblemException, common_error_handler) + + +def init_error_handlers(app: Flask): + """Add custom errors handlers.""" + from airflow.providers.fab.www import views - def handle_denied(error): - response = {"error": "Access is denied"} - return jsonify(response), status.HTTP_403_FORBIDDEN + app.register_error_handler(500, views.show_traceback) + app.register_error_handler(404, views.not_found) - app.register_error_handler(404, handle_not_found) - app.register_error_handler(BadRequest, handle_bad_request) - app.register_error_handler(NotFound, handle_not_found) - app.register_error_handler(Unauthenticated, handle_unauthenticated) - app.register_error_handler(PermissionDenied, handle_denied) +def init_api_auth_provider(app): + """Initialize the API offered by the auth manager.""" + auth_mgr = get_auth_manager() + blueprint = auth_mgr.get_api_endpoints() + if blueprint: + base_paths.append(blueprint.url_prefix) + app.register_blueprint(blueprint) + app.extensions["csrf"].exempt(blueprint) diff --git a/providers/fab/src/airflow/providers/fab/www/views.py b/providers/fab/src/airflow/providers/fab/www/views.py index 6f6a02cf00a9a..1f1c4e70addbd 100644 --- a/providers/fab/src/airflow/providers/fab/www/views.py +++ b/providers/fab/src/airflow/providers/fab/www/views.py @@ -19,11 +19,14 @@ import sys import traceback +from urllib.parse import unquote, urljoin, urlsplit from flask import ( g, redirect, render_template, + request, + url_for, ) from flask_appbuilder import IndexView, expose @@ -32,6 +35,26 @@ from airflow.utils.net import get_hostname from airflow.version import version +# Following the release of https://github.com/python/cpython/issues/102153 in Python 3.9.17 on +# June 6, 2023, we are adding extra sanitization of the urls passed to get_safe_url method to make it works +# the same way regardless if the user uses latest Python patchlevel versions or not. This also follows +# a recommended solution by the Python core team. +# +# From: https://github.com/python/cpython/commit/d28bafa2d3e424b6fdcfd7ae7cde8e71d7177369 +# +# We recommend that users of these APIs where the values may be used anywhere +# with security implications code defensively. Do some verification within your +# code before trusting a returned component part. Does that ``scheme`` make +# sense? Is that a sensible ``path``? Is there anything strange about that +# ``hostname``? etc. +# +# C0 control and space to be stripped per WHATWG spec. +# == "".join([chr(i) for i in range(0, 0x20 + 1)]) +_WHATWG_C0_CONTROL_OR_SPACE = ( + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c" + "\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f " +) + class FabIndexView(IndexView): """ @@ -41,8 +64,6 @@ class FabIndexView(IndexView): authenticated. It is impossible to redirect the user directly to the Airflow 3 UI index page before redirecting them to this page because FAB itself defines the logic redirection and does not allow external redirect. - - It is impossible to redirect the user before """ @expose("/") @@ -75,3 +96,37 @@ def show_traceback(error): ), 500, ) + + +def not_found(error): + """Show Not Found on screen for any error in the Webserver.""" + return ( + render_template( + "airflow/error.html", + hostname="", + status_code=404, + error_message="Page cannot be found.", + ), + 404, + ) + + +def get_safe_url(url): + """Given a user-supplied URL, ensure it points to our web server.""" + if not url: + return url_for("FabIndexView.index") + + # If the url contains semicolon, redirect it to homepage to avoid + # potential XSS. (Similar to https://github.com/python/cpython/pull/24297/files (bpo-42967)) + if ";" in unquote(url): + return url_for("FabIndexView.index") + + url = url.lstrip(_WHATWG_C0_CONTROL_OR_SPACE) + + host_url = urlsplit(request.host_url) + redirect_url = urlsplit(urljoin(request.host_url, url)) + if not (redirect_url.scheme in ("http", "https") and host_url.netloc == redirect_url.netloc): + return url_for("FabIndexView.index") + + # This will ensure we only redirect to the right scheme/netloc + return redirect_url.geturl() diff --git a/providers/fab/tests/unit/fab/auth_manager/api/auth/backend/test_basic_auth.py b/providers/fab/tests/unit/fab/auth_manager/api/auth/backend/test_basic_auth.py index af893b87c8b60..845e601a947e9 100644 --- a/providers/fab/tests/unit/fab/auth_manager/api/auth/backend/test_basic_auth.py +++ b/providers/fab/tests/unit/fab/auth_manager/api/auth/backend/test_basic_auth.py @@ -23,12 +23,12 @@ from flask_appbuilder.const import AUTH_LDAP from airflow.providers.fab.auth_manager.api.auth.backend.basic_auth import requires_authentication -from airflow.www import app as application +from airflow.providers.fab.www import app as application @pytest.fixture def app(): - return application.create_app(testing=True) + return application.create_app(enable_plugins=False) @pytest.fixture diff --git a/providers/fab/tests/unit/fab/auth_manager/api/auth/backend/test_session.py b/providers/fab/tests/unit/fab/auth_manager/api/auth/backend/test_session.py index ed7cd2bf45869..296edd11b7bc9 100644 --- a/providers/fab/tests/unit/fab/auth_manager/api/auth/backend/test_session.py +++ b/providers/fab/tests/unit/fab/auth_manager/api/auth/backend/test_session.py @@ -22,12 +22,12 @@ from flask import Response from airflow.providers.fab.auth_manager.api.auth.backend.session import requires_authentication -from airflow.www import app as application +from airflow.providers.fab.www import app as application @pytest.fixture def app(): - return application.create_app(testing=True) + return application.create_app(enable_plugins=False) mock_call = Mock() diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_asset_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_asset_endpoint.py deleted file mode 100644 index 5eb3097217917..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_asset_endpoint.py +++ /dev/null @@ -1,325 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from collections.abc import Generator - -import pytest -import time_machine - -from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.providers.fab.www.security import permissions -from airflow.utils import timezone -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user - -from tests_common.test_utils.db import clear_db_assets, clear_db_runs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS -from tests_common.test_utils.www import _check_last_log - -try: - from airflow.models.asset import AssetDagRunQueue, AssetModel -except ImportError: - if AIRFLOW_V_3_0_PLUS: - raise - else: - pass - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test_queued_event", - role_name="TestQueuedEvent", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ASSET), - ], - ) - - yield app - - delete_user(app, username="test_queued_event") - - -class TestAssetEndpoint: - default_time = "2020-06-11T18:00:00+00:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() - clear_db_assets() - clear_db_runs() - - def teardown_method(self) -> None: - clear_db_assets() - clear_db_runs() - - def _create_asset(self, session): - asset_model = AssetModel( - id=1, - uri="s3://bucket/key", - extra={"foo": "bar"}, - created_at=timezone.parse(self.default_time), - updated_at=timezone.parse(self.default_time), - ) - session.add(asset_model) - session.commit() - return asset_model - - -class TestQueuedEventEndpoint(TestAssetEndpoint): - @pytest.fixture - def time_freezer(self) -> Generator: - freezer = time_machine.travel(self.default_time, tick=False) - freezer.start() - - yield - - freezer.stop() - - def _create_asset_dag_run_queues(self, dag_id, asset_id, session): - ddrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id) - session.add(ddrq) - session.commit() - return ddrq - - -class TestGetDagAssetQueuedEvent(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_id = self._create_asset(session).id - self._create_asset_dag_run_queues(dag_id, asset_id, session) - asset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - - def test_should_respond_404(self): - dag_id = "not_exists" - asset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestDeleteDagAssetQueuedEvent(TestAssetEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_uri = "s3://bucket/key" - asset_id = self._create_asset(session).id - - ddrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id) - session.add(ddrq) - session.commit() - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 1 - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log(session, dag_id=dag_id, event="api.delete_dag_asset_queued_event", logical_date=None) - - def test_should_respond_404(self): - dag_id = "not_exists" - asset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with dag_id: `not_exists` and asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestGetDagAssetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_id = self._create_asset(session).id - self._create_asset_dag_run_queues(dag_id, asset_id, session) - - response = self.client.get( - f"/api/v1/dags/{dag_id}/assets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.get( - f"/api/v1/dags/{dag_id}/assets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestDeleteDagDatasetQueuedEvents(TestAssetEndpoint): - def test_should_respond_404(self): - dag_id = "not_exists" - - response = self.client.delete( - f"/api/v1/dags/{dag_id}/assets/queuedEvent", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with dag_id: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): - @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_id = self._create_asset(session).id - self._create_asset_dag_run_queues(dag_id, asset_id, session) - asset_uri = "s3://bucket/key" - - response = self.client.get( - f"/api/v1/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 200 - assert response.json == { - "queued_events": [ - { - "created_at": self.default_time, - "uri": "s3://bucket/key", - "dag_id": "dag", - } - ], - "total_entries": 1, - } - - def test_should_respond_404(self): - asset_uri = "not_exists" - - response = self.client.get( - f"/api/v1/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } - - -class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): - def test_delete_should_respond_204(self, session, create_dummy_dag): - dag, _ = create_dummy_dag() - dag_id = dag.dag_id - asset_id = self._create_asset(session).id - self._create_asset_dag_run_queues(dag_id, asset_id, session) - asset_uri = "s3://bucket/key" - - response = self.client.delete( - f"/api/v1/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 204 - conn = session.query(AssetDagRunQueue).all() - assert len(conn) == 0 - _check_last_log(session, dag_id=None, event="api.delete_asset_queued_events", logical_date=None) - - def test_should_respond_404(self): - asset_uri = "not_exists" - - response = self.client.delete( - f"/api/v1/assets/queuedEvent/{asset_uri}", - environ_overrides={"REMOTE_USER": "test_queued_event"}, - ) - - assert response.status_code == 404 - assert response.json == { - "detail": "Queue event with asset uri: `not_exists` was not found", - "status": 404, - "title": "Queue event not found", - "type": EXCEPTIONS_LINK_MAP[404], - } diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py index 4f8bc12702ff4..4814941ba154d 100644 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py +++ b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_auth.py @@ -21,11 +21,9 @@ import pytest from flask_login import current_user -from tests_common.test_utils.api_connexion_utils import assert_401 from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_pools from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS -from tests_common.test_utils.www import client_with_login pytestmark = [ pytest.mark.db_test, @@ -55,7 +53,7 @@ def set_attrs(self, minimal_app_for_auth_api): class TestBasicAuth(BaseTestAuth): @pytest.fixture(autouse=True, scope="class") def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth + from airflow.providers.fab.www.extensions.init_security import init_api_auth old_auth = getattr(minimal_app_for_auth_api, "api_auth") @@ -73,103 +71,7 @@ def test_success(self): clear_db_pools() with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) + response = test_client.get("/auth/fab/v1/users", headers={"Authorization": token}) assert current_user.email == "test@fab.org" assert response.status_code == 200 - assert response.json == { - "pools": [ - { - "name": "default_pool", - "slots": 128, - "occupied_slots": 0, - "running_slots": 0, - "queued_slots": 0, - "scheduled_slots": 0, - "deferred_slots": 0, - "open_slots": 128, - "description": "Default pool", - "include_deferred": False, - }, - ], - "total_entries": 1, - } - - @pytest.mark.parametrize( - "token", - [ - "basic", - "basic ", - "bearer", - "test:test", - b64encode(b"test:test").decode(), - "bearer ", - "basic: ", - "basic 123", - ], - ) - def test_malformed_headers(self, token): - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert response.headers["WWW-Authenticate"] == "Basic" - assert_401(response) - - @pytest.mark.parametrize( - "token", - [ - "basic " + b64encode(b"test").decode(), - "basic " + b64encode(b"test:").decode(), - "basic " + b64encode(b"test:123").decode(), - "basic " + b64encode(b"test test").decode(), - ], - ) - def test_invalid_auth_header(self, token): - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 401 - assert response.headers["Content-Type"] == "application/problem+json" - assert response.headers["WWW-Authenticate"] == "Basic" - assert_401(response) - - -class TestSessionWithBasicAuthFallback(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_auth_api, "api_auth") - - try: - with conf_vars( - { - ( - "api", - "auth_backends", - ): "airflow.providers.fab.auth_manager.api.auth.backend.session,airflow.providers.fab.auth_manager.api.auth.backend.basic_auth" - } - ): - init_api_auth(minimal_app_for_auth_api) - yield - finally: - setattr(minimal_app_for_auth_api, "api_auth", old_auth) - - def test_basic_auth_fallback(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - # request uses session - admin_user = client_with_login(self.app, username="test", password="test") - response = admin_user.get("/api/v1/pools") - assert response.status_code == 200 - - # request uses basic auth - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 200 - - # request without session or basic auth header - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools") - assert response.status_code == 401 diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_cors.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_cors.py deleted file mode 100644 index b8947925b1ec5..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_cors.py +++ /dev/null @@ -1,155 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from base64 import b64encode - -import pytest - -from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_pools -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -class BaseTestAuth: - @pytest.fixture(autouse=True) - def set_attrs(self, minimal_app_for_auth_api): - self.app = minimal_app_for_auth_api - - sm = self.app.appbuilder.sm - tester = sm.find_user(username="test") - if not tester: - role_admin = sm.find_role("Admin") - sm.add_user( - username="test", - first_name="test", - last_name="test", - email="test@fab.org", - role=role_admin, - password="test", - ) - - -class TestEmptyCors(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_auth_api, "api_auth") - - try: - with conf_vars( - {("api", "auth_backends"): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth"} - ): - init_api_auth(minimal_app_for_auth_api) - yield - finally: - setattr(minimal_app_for_auth_api, "api_auth", old_auth) - - def test_empty_cors_headers(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 200 - assert "Access-Control-Allow-Headers" not in response.headers - assert "Access-Control-Allow-Methods" not in response.headers - assert "Access-Control-Allow-Origin" not in response.headers - - -class TestCorsOrigin(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_auth_api, "api_auth") - - try: - with conf_vars( - { - ( - "api", - "auth_backends", - ): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth", - ("api", "access_control_allow_origins"): "http://apache.org http://example.com", - } - ): - init_api_auth(minimal_app_for_auth_api) - yield - finally: - setattr(minimal_app_for_auth_api, "api_auth", old_auth) - - def test_cors_origin_reflection(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - with self.app.test_client() as test_client: - response = test_client.get("/api/v1/pools", headers={"Authorization": token}) - assert response.status_code == 200 - assert response.headers["Access-Control-Allow-Origin"] == "http://apache.org" - - response = test_client.get( - "/api/v1/pools", headers={"Authorization": token, "Origin": "http://apache.org"} - ) - assert response.status_code == 200 - assert response.headers["Access-Control-Allow-Origin"] == "http://apache.org" - - response = test_client.get( - "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} - ) - assert response.status_code == 200 - assert response.headers["Access-Control-Allow-Origin"] == "http://example.com" - - -class TestCorsWildcard(BaseTestAuth): - @pytest.fixture(autouse=True, scope="class") - def with_basic_auth_backend(self, minimal_app_for_auth_api): - from airflow.www.extensions.init_security import init_api_auth - - old_auth = getattr(minimal_app_for_auth_api, "api_auth") - - try: - with conf_vars( - { - ( - "api", - "auth_backends", - ): "airflow.providers.fab.auth_manager.api.auth.backend.basic_auth", - ("api", "access_control_allow_origins"): "*", - } - ): - init_api_auth(minimal_app_for_auth_api) - yield - finally: - setattr(minimal_app_for_auth_api, "api_auth", old_auth) - - def test_cors_origin_reflection(self): - token = "Basic " + b64encode(b"test:test").decode() - clear_db_pools() - - with self.app.test_client() as test_client: - response = test_client.get( - "/api/v1/pools", headers={"Authorization": token, "Origin": "http://example.com"} - ) - assert response.status_code == 200 - assert response.headers["Access-Control-Allow-Origin"] == "*" diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_endpoint.py deleted file mode 100644 index 903038d7678e3..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_endpoint.py +++ /dev/null @@ -1,226 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pendulum -import pytest - -from airflow.models import DagModel -from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.providers.fab.www.security import permissions -from airflow.utils.session import provide_session -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user - -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS -from tests_common.test_utils.www import _check_last_log - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture -def current_file_token(url_safe_serializer) -> str: - return url_safe_serializer.dumps(__file__) - - -DAG_ID = "test_dag" -TASK_ID = "op1" -DAG2_ID = "test_dag2" -DAG3_ID = "test_dag3" -UTC_JSON_REPR = "UTC" if pendulum.__version__.startswith("3") else "Timezone('UTC')" - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - - create_user(app, username="test_granular_permissions", role_name="TestGranularDag") - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, - ) - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_1", - access_control={ - "TestGranularDag": { - permissions.RESOURCE_DAG: {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ} - }, - }, - ) - - yield app - - delete_user(app, username="test_granular_permissions") - - -class TestDagEndpoint: - @staticmethod - def clean_db(): - clear_db_runs() - clear_db_dags() - clear_db_serialized_dags() - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.clean_db() - self.app = configured_app - self.client = self.app.test_client() # type:ignore - self.dag_id = DAG_ID - self.dag2_id = DAG2_ID - self.dag3_id = DAG3_ID - - def teardown_method(self) -> None: - self.clean_db() - - @provide_session - def _create_dag_models(self, count, dag_id_prefix="TEST_DAG", is_paused=False, session=None): - for num in range(1, count + 1): - dag_model = DagModel( - dag_id=f"{dag_id_prefix}_{num}", - fileloc=f"/tmp/dag_{num}.py", - timetable_summary="2 2 * * *", - is_active=True, - is_paused=is_paused, - ) - session.add(dag_model) - - @provide_session - def _create_dag_model_for_details_endpoint(self, dag_id, session=None): - dag_model = DagModel( - dag_id=dag_id, - fileloc="/tmp/dag.py", - timetable_summary="2 2 * * *", - is_active=True, - is_paused=False, - ) - session.add(dag_model) - - @provide_session - def _create_dag_model_for_details_endpoint_with_asset_expression(self, dag_id, session=None): - dag_model = DagModel( - dag_id=dag_id, - fileloc="/tmp/dag.py", - timetable_summary="2 2 * * *", - is_active=True, - is_paused=False, - asset_expression={ - "any": [ - "s3://dag1/output_1.txt", - {"all": ["s3://dag2/output_1.txt", "s3://dag3/output_3.txt"]}, - ] - }, - ) - session.add(dag_model) - - @provide_session - def _create_deactivated_dag(self, session=None): - dag_model = DagModel( - dag_id="TEST_DAG_DELETED_1", - fileloc="/tmp/dag_del_1.py", - timetable_summary="2 2 * * *", - is_active=False, - ) - session.add(dag_model) - - -class TestGetDag(TestDagEndpoint): - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(1) - response = self.client.get( - "/api/v1/dags/TEST_DAG_1", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - - def test_should_respond_403_with_granular_access_for_different_dag(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags/TEST_DAG_2", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 403 - - -class TestGetDags(TestDagEndpoint): - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.get( - "/api/v1/dags", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" - - -class TestPatchDag(TestDagEndpoint): - @provide_session - def _create_dag_model(self, session=None): - dag_model = DagModel( - dag_id="TEST_DAG_1", fileloc="/tmp/dag_1.py", timetable_summary="2 2 * * *", is_paused=True - ) - session.add(dag_model) - return dag_model - - def test_should_respond_200_on_patch_with_granular_dag_access(self, session): - self._create_dag_models(1) - response = self.client.patch( - "/api/v1/dags/TEST_DAG_1", - json={ - "is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - _check_last_log(session, dag_id="TEST_DAG_1", event="api.patch_dag", logical_date=None) - - def test_validation_error_raises_400(self): - patch_body = { - "ispaused": True, - } - dag_model = self._create_dag_model() - response = self.client.patch( - f"/api/v1/dags/{dag_model.dag_id}", - json=patch_body, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 400 - assert response.json == { - "detail": "{'ispaused': ['Unknown field.']}", - "status": 400, - "title": "Bad Request", - "type": EXCEPTIONS_LINK_MAP[400], - } - - -class TestPatchDags(TestDagEndpoint): - def test_should_respond_200_with_granular_dag_access(self): - self._create_dag_models(3) - response = self.client.patch( - "api/v1/dags?dag_id_pattern=~", - json={ - "is_paused": False, - }, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert len(response.json["dags"]) == 1 - assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1" diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py deleted file mode 100644 index 10c8817fd31fe..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_run_endpoint.py +++ /dev/null @@ -1,277 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from datetime import timedelta - -import pytest - -from airflow.models.dag import DagModel -from airflow.models.dagrun import DagRun -from airflow.providers.fab.www.security import permissions -from airflow.sdk.definitions.param import Param -from airflow.utils import timezone -from airflow.utils.session import create_session -from airflow.utils.state import DagRunState -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import ( - create_user, - delete_roles, - delete_user, -) - -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -try: - from airflow.utils.types import DagRunTriggeredByType, DagRunType -except ImportError: - if AIRFLOW_V_3_0_PLUS: - raise - else: - pass - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - - create_user( - app, - username="test_no_dag_run_create_permission", - role_name="TestNoDagRunCreatePermission", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_ASSET), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_CLUSTER_ACTIVITY), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, - username="test_dag_view_only", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, - username="test_view_dags", - role_name="TestViewDags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], - ) - create_user( - app, - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN)], - ) - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_ID", - access_control={ - "TestGranularDag": {permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ}, - "TestNoDagRunCreatePermission": {permissions.RESOURCE_DAG_RUN: {permissions.ACTION_CAN_CREATE}}, - }, - ) - - yield app - - delete_user(app, username="test_dag_view_only") - delete_user(app, username="test_view_dags") - delete_user(app, username="test_granular_permissions") - delete_user(app, username="test_no_dag_run_create_permission") - delete_roles(app) - - -class TestDagRunEndpoint: - default_time = "2020-06-11T18:00:00+00:00" - default_time_2 = "2020-06-12T18:00:00+00:00" - default_time_3 = "2020-06-13T18:00:00+00:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - clear_db_runs() - clear_db_serialized_dags() - clear_db_dags() - - @pytest.fixture(autouse=True) - def create_dag(self, dag_maker, setup_attrs): - with dag_maker( - "TEST_DAG_ID", schedule=None, params={"validated_number": Param(1, minimum=1, maximum=10)} - ): - pass - - dag_maker.sync_dagbag_to_db() - - def teardown_method(self) -> None: - clear_db_runs() - clear_db_dags() - clear_db_serialized_dags() - - def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): - dag_runs = [] - dags = [] - - def _v3_kwargs(date): - if AIRFLOW_V_3_0_PLUS: - return { - "data_interval": (date, date), - "logical_date": date, - "run_after": date, - "triggered_by": DagRunTriggeredByType.TEST, - } - return {"execution_date": date} - - for i in range(idx_start, idx_start + 2): - dagrun_model = DagRun( - dag_id="TEST_DAG_ID", - run_id=f"TEST_DAG_RUN_ID_{i}", - run_type=DagRunType.MANUAL, - start_date=timezone.parse(self.default_time), - external_trigger=True, - state=state, - **_v3_kwargs(timezone.parse(self.default_time) + timedelta(days=i - 1)), - ) - dagrun_model.updated_at = timezone.parse(self.default_time) - dag_runs.append(dagrun_model) - - if extra_dag: - for i in range(idx_start + 2, idx_start + 4): - dags.append(DagModel(dag_id=f"TEST_DAG_ID_{i}")) - dag_runs.append( - DagRun( - dag_id=f"TEST_DAG_ID_{i}", - run_id=f"TEST_DAG_RUN_ID_{i}", - run_type=DagRunType.MANUAL, - start_date=timezone.parse(self.default_time), - external_trigger=True, - state=state, - **_v3_kwargs(timezone.parse(self.default_time_2)), - ) - ) - if commit: - with create_session() as session: - session.add_all(dag_runs) - session.add_all(dags) - return dag_runs - - -class TestGetDagRuns(TestDagRunEndpoint): - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_dag_run_ids = ["TEST_DAG_ID", "TEST_DAG_ID"] - response = self.client.get( - "api/v1/dags/~/dagRuns", environ_overrides={"REMOTE_USER": "test_granular_permissions"} - ) - assert response.status_code == 200 - dag_run_ids = [dag_run["dag_id"] for dag_run in response.json["dag_runs"]] - assert dag_run_ids == expected_dag_run_ids - - -class TestGetDagRunBatch(TestDagRunEndpoint): - def test_should_return_accessible_with_tilde_as_dag_id_and_dag_level_permissions(self): - self._create_test_dag_run(extra_dag=True) - expected_response_json_1 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_1", - "end_date": None, - "state": "running", - "logical_date": self.default_time, - "run_after": self.default_time, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": self.default_time, - "data_interval_start": self.default_time, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_1.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - expected_response_json_2 = { - "dag_id": "TEST_DAG_ID", - "dag_run_id": "TEST_DAG_RUN_ID_2", - "end_date": None, - "state": "running", - "logical_date": self.default_time_2, - "run_after": self.default_time_2, - "external_trigger": True, - "start_date": self.default_time, - "conf": {}, - "data_interval_end": self.default_time_2, - "data_interval_start": self.default_time_2, - "last_scheduling_decision": None, - "run_type": "manual", - "note": None, - } - expected_response_json_2.update({"triggered_by": "test"} if AIRFLOW_V_3_0_PLUS else {}) - - response = self.client.post( - "api/v1/dags/~/dagRuns/list", - json={"dag_ids": []}, - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - assert response.status_code == 200 - assert response.json == { - "dag_runs": [ - expected_response_json_1, - expected_response_json_2, - ], - "total_entries": 2, - } - - -class TestPostDagRun(TestDagRunEndpoint): - def test_dagrun_trigger_with_dag_level_permissions(self): - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", - json={"conf": {"validated_number": 1}}, - environ_overrides={"REMOTE_USER": "test_no_dag_run_create_permission"}, - ) - assert response.status_code == 200 - - @pytest.mark.parametrize( - "username", - ["test_dag_view_only", "test_view_dags", "test_granular_permissions"], - ) - def test_should_raises_403_unauthorized(self, username): - response = self.client.post( - "api/v1/dags/TEST_DAG_ID/dagRuns", - json={ - "dag_run_id": "TEST_DAG_RUN_ID_1", - "logical_date": self.default_time, - }, - environ_overrides={"REMOTE_USER": username}, - ) - assert response.status_code == 403 diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py deleted file mode 100644 index fa28cb3f3b8f6..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_source_endpoint.py +++ /dev/null @@ -1,132 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import ast -import os - -import pytest - -from airflow.models import DagBag -from airflow.providers.fab.www.security import permissions -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user - -from tests_common.test_utils.db import ( - clear_db_dag_code, - clear_db_dags, - clear_db_serialized_dags, - parse_and_sync_to_db, -) -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -EXAMPLE_DAG_ID = "example_bash_operator" -TEST_DAG_ID = "latest_only" -NOT_READABLE_DAG_ID = "latest_only_with_trigger" -TEST_MULTIPLE_DAGS_ID = "asset_produces_1" - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test", - role_name="Test", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)], - ) - app.appbuilder.sm.sync_perm_for_dag( - TEST_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( - EXAMPLE_DAG_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( - TEST_MULTIPLE_DAGS_ID, - access_control={"Test": [permissions.ACTION_CAN_READ]}, - ) - - yield app - - delete_user(app, username="test") - - -class TestGetSource: - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - self.clear_db() - - def teardown_method(self) -> None: - self.clear_db() - - @staticmethod - def clear_db(): - clear_db_dags() - clear_db_serialized_dags() - clear_db_dag_code() - - @staticmethod - def _get_dag_file_docstring(fileloc: str) -> str | None: - with open(fileloc) as f: - file_contents = f.read() - module = ast.parse(file_contents) - docstring = ast.get_docstring(module) - return docstring - - def test_should_respond_403_not_readable(self, url_safe_serializer): - parse_and_sync_to_db(os.devnull, include_examples=True) - dagbag = DagBag(read_dags_from_db=True) - dag = dagbag.get_dag(NOT_READABLE_DAG_ID) - - response = self.client.get( - f"/api/v1/dagSources/{dag.dag_id}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - read_dag = self.client.get( - f"/api/v1/dags/{NOT_READABLE_DAG_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert read_dag.status_code == 403 - - def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_serializer): - parse_and_sync_to_db(os.devnull, include_examples=True) - dagbag = DagBag(read_dags_from_db=True) - dag = dagbag.get_dag(TEST_MULTIPLE_DAGS_ID) - - response = self.client.get( - f"/api/v1/dagSources/{dag.dag_id}", - headers={"Accept": "text/plain"}, - environ_overrides={"REMOTE_USER": "test"}, - ) - - read_dag = self.client.get( - f"/api/v1/dags/{TEST_MULTIPLE_DAGS_ID}", - environ_overrides={"REMOTE_USER": "test"}, - ) - assert response.status_code == 403 - assert read_dag.status_code == 200 diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py deleted file mode 100644 index d7ff3cf87940c..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_dag_warning_endpoint.py +++ /dev/null @@ -1,84 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.models.dag import DagModel -from airflow.models.dagwarning import DagWarning -from airflow.providers.fab.www.security import permissions -from airflow.utils.session import create_session -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user - -from tests_common.test_utils.db import clear_db_dag_warnings, clear_db_dags -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, # type:ignore - username="test_with_dag2_read", - role_name="TestWithDag2Read", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - (permissions.ACTION_CAN_READ, f"{permissions.RESOURCE_DAG_PREFIX}dag2"), - ], - ) - - yield app - - delete_user(app, username="test_with_dag2_read") - - -class TestBaseDagWarning: - timestamp = "2020-06-10T12:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - - def teardown_method(self) -> None: - clear_db_dag_warnings() - clear_db_dags() - - -class TestGetDagWarningEndpoint(TestBaseDagWarning): - def setup_class(self): - clear_db_dag_warnings() - clear_db_dags() - - def setup_method(self): - with create_session() as session: - session.add(DagModel(dag_id="dag1")) - session.add(DagWarning("dag1", "non-existent pool", "test message")) - session.commit() - - def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): - response = self.client.get( - "/api/v1/dagWarnings", - environ_overrides={"REMOTE_USER": "test_with_dag2_read"}, - query_string={"dag_id": "dag1"}, - ) - assert response.status_code == 403 diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_event_log_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_event_log_endpoint.py deleted file mode 100644 index b890d92c13b5f..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_event_log_endpoint.py +++ /dev/null @@ -1,151 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.models import Log -from airflow.providers.fab.www.security import permissions -from airflow.utils import timezone -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user - -from tests_common.test_utils.db import clear_db_logs -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test_granular", - role_name="TestGranular", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)], - ) - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_ID_1", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - app.appbuilder.sm.sync_perm_for_dag( - "TEST_DAG_ID_2", - access_control={"TestGranular": [permissions.ACTION_CAN_READ]}, - ) - - yield app - - delete_user(app, username="test_granular") - - -@pytest.fixture -def task_instance(session, create_task_instance, request): - return create_task_instance( - session=session, - dag_id="TEST_DAG_ID", - task_id="TEST_TASK_ID", - run_id="TEST_RUN_ID", - logical_date=request.instance.default_time, - ) - - -@pytest.fixture -def create_log_model(create_task_instance, task_instance, session, request): - def maker(event, when, **kwargs): - log_model = Log( - event=event, - task_instance=task_instance, - **kwargs, - ) - log_model.dttm = when - - session.add(log_model) - session.flush() - return log_model - - return maker - - -class TestEventLogEndpoint: - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - clear_db_logs() - self.default_time = timezone.parse("2020-06-10T20:00:00+00:00") - self.default_time_2 = timezone.parse("2020-06-11T07:00:00+00:00") - - def teardown_method(self) -> None: - clear_db_logs() - - -class TestGetEventLogs(TestEventLogEndpoint): - def test_should_filter_eventlogs_by_allowed_attributes(self, create_log_model, session): - eventlog1 = create_log_model( - event="TEST_EVENT_1", - dag_id="TEST_DAG_ID_1", - task_id="TEST_TASK_ID_1", - owner="TEST_OWNER_1", - when=self.default_time, - ) - eventlog2 = create_log_model( - event="TEST_EVENT_2", - dag_id="TEST_DAG_ID_2", - task_id="TEST_TASK_ID_2", - owner="TEST_OWNER_2", - when=self.default_time_2, - ) - session.add_all([eventlog1, eventlog2]) - session.commit() - for attr in ["dag_id", "task_id", "owner", "event"]: - attr_value = f"TEST_{attr}_1".upper() - response = self.client.get( - f"/api/v1/eventLogs?{attr}={attr_value}", environ_overrides={"REMOTE_USER": "test_granular"} - ) - assert response.status_code == 200 - assert response.json["total_entries"] == 1 - assert len(response.json["event_logs"]) == 1 - assert response.json["event_logs"][0][attr] == attr_value - - def test_should_filter_eventlogs_by_included_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?included_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 2 - assert response_data["total_entries"] == 2 - assert {"TEST_EVENT_1", "TEST_EVENT_2"} == {x["event"] for x in response_data["event_logs"]} - - def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): - for event in ["TEST_EVENT_1", "TEST_EVENT_2", "cli_scheduler"]: - create_log_model(event=event, when=self.default_time) - response = self.client.get( - "/api/v1/eventLogs?excluded_events=TEST_EVENT_1,TEST_EVENT_2", - environ_overrides={"REMOTE_USER": "test_granular"}, - ) - assert response.status_code == 200 - response_data = response.json - assert len(response_data["event_logs"]) == 1 - assert response_data["total_entries"] == 1 - assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_import_error_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_import_error_endpoint.py deleted file mode 100644 index 0028d10618bb7..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_import_error_endpoint.py +++ /dev/null @@ -1,244 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.dag_processing.bundles.manager import DagBundlesManager -from airflow.models.dag import DagModel -from airflow.providers.fab.www.security import permissions -from airflow.utils import timezone -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user - -from tests_common.test_utils.compat import ParseImportError -from tests_common.test_utils.db import clear_db_dags, clear_db_import_errors -from tests_common.test_utils.permissions import _resource_name -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - -TEST_DAG_IDS = ["test_dag", "test_dag2"] -BUNDLE_NAME = "testing" - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test_single_dag", - role_name="TestSingleDAG", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)], - ) - # For some reason, DAG level permissions are not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestSingleDAG", - "perms": [ - ( - permissions.ACTION_CAN_READ, - _resource_name(TEST_DAG_IDS[0], permissions.RESOURCE_DAG), - ) - ], - } - ] - ) - yield app - - delete_user(app, username="test_single_dag") - - -class TestBaseImportError: - timestamp = "2020-06-10T12:00" - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - - clear_db_import_errors() - clear_db_dags() - - def teardown_method(self) -> None: - clear_db_import_errors() - clear_db_dags() - - @staticmethod - def _normalize_import_errors(import_errors): - for i, import_error in enumerate(import_errors, 1): - import_error["import_error_id"] = i - - -class TestGetImportErrorEndpoint(TestBaseImportError): - def test_should_raise_403_forbidden_without_dag_read(self, configure_testing_dag_bundle, session): - with configure_testing_dag_bundle("/tmp"): - DagBundlesManager().sync_bundles_to_db() - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - bundle_name=BUNDLE_NAME, - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 403 - - def test_should_return_200_with_single_dag_read(self, session, configure_testing_dag_bundle): - with configure_testing_dag_bundle("/tmp"): - DagBundlesManager().sync_bundles_to_db() - dag_model = DagModel(dag_id=TEST_DAG_IDS[0], fileloc="Lorem_ipsum.py", bundle_name=BUNDLE_NAME) - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - bundle_name=BUNDLE_NAME, - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert response_data == { - "filename": "Lorem_ipsum.py", - "bundle_name": BUNDLE_NAME, - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - } - - def test_should_return_200_redacted_with_single_dag_read_in_dagfile( - self, configure_testing_dag_bundle, session - ): - with configure_testing_dag_bundle("/tmp"): - DagBundlesManager().sync_bundles_to_db() - for dag_id in TEST_DAG_IDS: - dag_model = DagModel(dag_id=dag_id, fileloc="Lorem_ipsum.py", bundle_name=BUNDLE_NAME) - session.add(dag_model) - import_error = ParseImportError( - filename="Lorem_ipsum.py", - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - bundle_name=BUNDLE_NAME, - ) - session.add(import_error) - session.commit() - - response = self.client.get( - f"/api/v1/importErrors/{import_error.id}", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - response_data["import_error_id"] = 1 - assert response_data == { - "filename": "Lorem_ipsum.py", - "bundle_name": BUNDLE_NAME, - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - } - - -class TestGetImportErrorsEndpoint(TestBaseImportError): - def test_get_import_errors_single_dag(self, configure_testing_dag_bundle, session): - with configure_testing_dag_bundle("/tmp"): - DagBundlesManager().sync_bundles_to_db() - for dag_id in TEST_DAG_IDS: - fake_filename = f"/tmp/{dag_id}.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename, bundle_name=BUNDLE_NAME) - session.add(dag_model) - importerror = ParseImportError( - filename=fake_filename, - bundle_name=BUNDLE_NAME, - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) - session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert response_data == { - "import_errors": [ - { - "filename": "/tmp/test_dag.py", - "bundle_name": BUNDLE_NAME, - "import_error_id": 1, - "stack_trace": "Lorem ipsum", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } - - def test_get_import_errors_single_dag_in_dagfile(self, configure_testing_dag_bundle, session): - with configure_testing_dag_bundle("/tmp"): - DagBundlesManager().sync_bundles_to_db() - for dag_id in TEST_DAG_IDS: - fake_filename = "/tmp/all_in_one.py" - dag_model = DagModel(dag_id=dag_id, fileloc=fake_filename, bundle_name="testing") - session.add(dag_model) - - importerror = ParseImportError( - filename="/tmp/all_in_one.py", - bundle_name=BUNDLE_NAME, - stacktrace="Lorem ipsum", - timestamp=timezone.parse(self.timestamp, timezone="UTC"), - ) - session.add(importerror) - session.commit() - - response = self.client.get( - "/api/v1/importErrors", environ_overrides={"REMOTE_USER": "test_single_dag"} - ) - - assert response.status_code == 200 - response_data = response.json - self._normalize_import_errors(response_data["import_errors"]) - assert response_data == { - "import_errors": [ - { - "filename": "/tmp/all_in_one.py", - "bundle_name": BUNDLE_NAME, - "import_error_id": 1, - "stack_trace": "REDACTED - you do not have read permission on all DAGs in the file", - "timestamp": "2020-06-10T12:00:00+00:00", - }, - ], - "total_entries": 1, - } diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py deleted file mode 100644 index b68b370e1b363..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_task_instance_endpoint.py +++ /dev/null @@ -1,428 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import datetime as dt -import urllib - -import pytest - -from airflow.models import DagRun, TaskInstance -from airflow.providers.fab.www.api_connexion.exceptions import EXCEPTIONS_LINK_MAP -from airflow.providers.fab.www.security import permissions -from airflow.utils.session import provide_session -from airflow.utils.state import State -from airflow.utils.timezone import datetime -from airflow.utils.types import DagRunType -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import ( - create_user, - delete_roles, - delete_user, -) - -from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - -DEFAULT_DATETIME_1 = datetime(2020, 1, 1) -DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00" -DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00" - -QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1) -QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2) - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - create_user( - app, - username="test_dag_read_only", - role_name="TestDagReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, - username="test_task_read_only", - role_name="TestTaskReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - create_user( - app, - username="test_read_only_one_dag", - role_name="TestReadOnlyOneDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) - # For some reason, "DAG:example_python_operator" is not synced when in the above list of perms, - # so do it manually here: - app.appbuilder.sm.bulk_sync_roles( - [ - { - "role": "TestReadOnlyOneDag", - "perms": [(permissions.ACTION_CAN_READ, "DAG:example_python_operator")], - } - ] - ) - - yield app - - delete_user(app, username="test_dag_read_only") - delete_user(app, username="test_task_read_only") - delete_user(app, username="test_read_only_one_dag") - delete_roles(app) - - -class TestTaskInstanceEndpoint: - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app, dagbag) -> None: - self.default_time = DEFAULT_DATETIME_1 - self.ti_init = { - "logical_date": self.default_time, - "state": State.RUNNING, - } - self.ti_extras = { - "start_date": self.default_time + dt.timedelta(days=1), - "end_date": self.default_time + dt.timedelta(days=2), - "pid": 100, - "duration": 10000, - "pool": "default_pool", - "queue": "default_queue", - "job_id": 0, - } - self.app = configured_app - self.client = self.app.test_client() # type:ignore - clear_db_runs() - clear_rendered_ti_fields() - self.dagbag = dagbag - - def create_task_instances( - self, - session, - dag_id: str = "example_python_operator", - update_extras: bool = True, - task_instances=None, - dag_run_state=State.RUNNING, - with_ti_history=False, - ): - """Method to create task instances using kwargs and default arguments""" - - dag = self.dagbag.get_dag(dag_id) - tasks = dag.tasks - counter = len(tasks) - if task_instances is not None: - counter = min(len(task_instances), counter) - - run_id = "TEST_DAG_RUN_ID" - logical_date = self.ti_init.pop("logical_date", self.default_time) - dr = None - - tis = [] - for i in range(counter): - if task_instances is None: - pass - elif update_extras: - self.ti_extras.update(task_instances[i]) - else: - self.ti_init.update(task_instances[i]) - - if "logical_date" in self.ti_init: - run_id = f"TEST_DAG_RUN_ID_{i}" - logical_date = self.ti_init.pop("logical_date") - dr = None - - if not dr: - dr = DagRun( - run_id=run_id, - dag_id=dag_id, - logical_date=logical_date, - data_interval=(logical_date, logical_date), - run_after=logical_date, - run_type=DagRunType.MANUAL, - state=dag_run_state, - ) - session.add(dr) - ti = TaskInstance(task=tasks[i], **self.ti_init) - session.add(ti) - ti.dag_run = dr - ti.note = "placeholder-note" - - for key, value in self.ti_extras.items(): - setattr(ti, key, value) - tis.append(ti) - - session.commit() - if with_ti_history: - for ti in tis: - ti.try_number = 1 - session.merge(ti) - session.commit() - dag.clear() - for ti in tis: - ti.try_number = 2 - ti.queue = "default_queue" - session.merge(ti) - session.commit() - return tis - - -class TestGetTaskInstance(TestTaskInstanceEndpoint): - def setup_method(self): - clear_db_runs() - - def teardown_method(self): - clear_db_runs() - - @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) - @provide_session - def test_should_respond_200(self, username, session): - self.create_task_instances(session) - # Update ti and set operator to None to - # test that operator field is nullable. - # This prevents issue when users upgrade to 2.0+ - # from 1.10.x - # https://github.com/apache/airflow/issues/14421 - session.query(TaskInstance).update({TaskInstance.operator: None}, synchronize_session="fetch") - session.commit() - response = self.client.get( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", - environ_overrides={"REMOTE_USER": username}, - ) - assert response.status_code == 200 - - -class TestGetTaskInstances(TestTaskInstanceEndpoint): - @pytest.mark.parametrize( - "task_instances, user, expected_ti", - [ - pytest.param( - { - "example_python_operator": 2, - "example_skip_dag": 1, - }, - "test_read_only_one_dag", - 2, - ), - pytest.param( - { - "example_python_operator": 1, - "example_skip_dag": 2, - }, - "test_read_only_one_dag", - 1, - ), - ], - ) - def test_return_TI_only_from_readable_dags(self, task_instances, user, expected_ti, session): - for dag_id in task_instances: - self.create_task_instances( - session, - task_instances=[ - {"logical_date": DEFAULT_DATETIME_1 + dt.timedelta(days=i)} - for i in range(task_instances[dag_id]) - ], - dag_id=dag_id, - ) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances", environ_overrides={"REMOTE_USER": user} - ) - assert response.status_code == 200 - assert response.json["total_entries"] == expected_ti - assert len(response.json["task_instances"]) == expected_ti - - -class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): - @pytest.mark.parametrize( - "task_instances, update_extras, payload, expected_ti_count, username", - [ - pytest.param( - [ - {"pool": "test_pool_1"}, - {"pool": "test_pool_2"}, - {"pool": "test_pool_3"}, - ], - True, - {"pool": ["test_pool_1", "test_pool_2"]}, - 2, - "test_dag_read_only", - id="test pool filter", - ), - pytest.param( - [ - {"state": State.RUNNING}, - {"state": State.QUEUED}, - {"state": State.SUCCESS}, - {"state": State.NONE}, - ], - False, - {"state": ["running", "queued", "none"]}, - 3, - "test_task_read_only", - id="test state filter", - ), - pytest.param( - [ - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - {"state": State.NONE}, - ], - False, - {}, - 4, - "test_task_read_only", - id="test dag with null states", - ), - pytest.param( - [ - {"end_date": DEFAULT_DATETIME_1}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"end_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "end_date_gte": DEFAULT_DATETIME_STR_1, - "end_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_task_read_only", - id="test end date filter", - ), - pytest.param( - [ - {"start_date": DEFAULT_DATETIME_1}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1)}, - {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2)}, - ], - True, - { - "start_date_gte": DEFAULT_DATETIME_STR_1, - "start_date_lte": DEFAULT_DATETIME_STR_2, - }, - 2, - "test_dag_read_only", - id="test start date filter", - ), - ], - ) - def test_should_respond_200( - self, task_instances, update_extras, payload, expected_ti_count, username, session - ): - self.create_task_instances( - session, - update_extras=update_extras, - task_instances=task_instances, - ) - response = self.client.post( - "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": username}, - json=payload, - ) - assert response.status_code == 200, response.json - assert expected_ti_count == response.json["total_entries"] - assert expected_ti_count == len(response.json["task_instances"]) - - def test_returns_403_forbidden_when_user_has_access_to_only_some_dags(self, session): - self.create_task_instances(session=session) - self.create_task_instances(session=session, dag_id="example_skip_dag") - payload = {"dag_ids": ["example_python_operator", "example_skip_dag"]} - - response = self.client.post( - "/api/v1/dags/~/dagRuns/~/taskInstances/list", - environ_overrides={"REMOTE_USER": "test_read_only_one_dag"}, - json=payload, - ) - assert response.status_code == 403 - assert response.json == { - "detail": "User not allowed to access some of these DAGs: ['example_python_operator', 'example_skip_dag']", - "status": 403, - "title": "Forbidden", - "type": EXCEPTIONS_LINK_MAP[403], - } - - -class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): - @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): - response = self.client.post( - "/api/v1/dags/example_python_operator/updateTaskInstancesState", - environ_overrides={"REMOTE_USER": username}, - json={ - "dry_run": True, - "task_id": "print_the_context", - "logical_date": DEFAULT_DATETIME_1.isoformat(), - "include_upstream": True, - "include_downstream": True, - "include_future": True, - "include_past": True, - "new_state": "failed", - }, - ) - assert response.status_code == 403 - - -class TestPatchTaskInstance(TestTaskInstanceEndpoint): - ENDPOINT_URL = ( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" - ) - - @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) - def test_should_raise_403_forbidden(self, username): - response = self.client.patch( - self.ENDPOINT_URL, - environ_overrides={"REMOTE_USER": username}, - json={ - "dry_run": True, - "new_state": "failed", - }, - ) - assert response.status_code == 403 - - -class TestGetTaskInstanceTry(TestTaskInstanceEndpoint): - def setup_method(self): - clear_db_runs() - - def teardown_method(self): - clear_db_runs() - - @pytest.mark.parametrize("username", ["test_dag_read_only", "test_task_read_only"]) - @provide_session - def test_should_respond_200(self, username, session): - self.create_task_instances(session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True) - - response = self.client.get( - "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", - environ_overrides={"REMOTE_USER": username}, - ) - assert response.status_code == 200 diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_variable_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_variable_endpoint.py deleted file mode 100644 index 4ee4db1230cb7..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_variable_endpoint.py +++ /dev/null @@ -1,88 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import pytest - -from airflow.models import Variable -from airflow.providers.fab.www.security import permissions -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user - -from tests_common.test_utils.db import clear_db_variables -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - - create_user( - app, - username="test_read_only", - role_name="TestReadOnly", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - ], - ) - create_user( - app, - username="test_delete_only", - role_name="TestDeleteOnly", - permissions=[ - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE), - ], - ) - - yield app - - delete_user(app, username="test_read_only") - delete_user(app, username="test_delete_only") - - -class TestVariableEndpoint: - @pytest.fixture(autouse=True) - def setup_method(self, configured_app) -> None: - self.app = configured_app - self.client = self.app.test_client() # type:ignore - clear_db_variables() - - def teardown_method(self) -> None: - clear_db_variables() - - -class TestGetVariable(TestVariableEndpoint): - @pytest.mark.parametrize( - "user, expected_status_code", - [ - ("test_read_only", 200), - ("test_delete_only", 403), - ], - ) - def test_read_variable(self, user, expected_status_code): - expected_value = '{"foo": 1}' - Variable.set("TEST_VARIABLE_KEY", expected_value) - response = self.client.get( - "/api/v1/variables/TEST_VARIABLE_KEY", environ_overrides={"REMOTE_USER": user} - ) - assert response.status_code == expected_status_code - if expected_status_code == 200: - assert response.json == {"key": "TEST_VARIABLE_KEY", "value": expected_value, "description": None} diff --git a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_xcom_endpoint.py b/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_xcom_endpoint.py deleted file mode 100644 index 46b558adfbf43..0000000000000 --- a/providers/fab/tests/unit/fab/auth_manager/api_endpoints/test_xcom_endpoint.py +++ /dev/null @@ -1,282 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from datetime import timedelta - -import pytest - -from airflow.models.dag import DagModel -from airflow.models.dagrun import DagRun -from airflow.models.taskinstance import TaskInstance -from airflow.models.xcom import BaseXCom, XCom -from airflow.providers.fab.www.security import permissions -from airflow.providers.standard.operators.empty import EmptyOperator -from airflow.utils import timezone -from airflow.utils.session import create_session -from airflow.utils.types import DagRunType -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user - -from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_db_xcom -from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS - -pytestmark = [ - pytest.mark.db_test, - pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Test requires Airflow 3.0+"), -] - - -class CustomXCom(BaseXCom): - @classmethod - def deserialize_value(cls, xcom: XCom): - return f"real deserialized {super().deserialize_value(xcom)}" - - def orm_deserialize_value(self): - return f"orm deserialized {super().orm_deserialize_value()}" - - -@pytest.fixture(scope="module") -def configured_app(minimal_app_for_auth_api): - app = minimal_app_for_auth_api - - create_user( - app, - username="test_granular_permissions", - role_name="TestGranularDag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], - ) - app.appbuilder.sm.sync_perm_for_dag( - "test-dag-id-1", - access_control={"TestGranularDag": [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]}, - ) - - yield app - - delete_user(app, username="test_granular_permissions") - - -def _compare_xcom_collections(collection1: dict, collection_2: dict): - assert collection1.get("total_entries") == collection_2.get("total_entries") - - def sort_key(record): - return ( - ( - record.get("dag_id"), - record.get("task_id"), - record.get("logical_date"), - record.get("map_index"), - record.get("key"), - ) - if AIRFLOW_V_3_0_PLUS - else ( - record.get("dag_id"), - record.get("task_id"), - record.get("execution_date"), - record.get("map_index"), - record.get("key"), - ) - ) - - assert sorted(collection1.get("xcom_entries", []), key=sort_key) == sorted( - collection_2.get("xcom_entries", []), key=sort_key - ) - - -class TestXComEndpoint: - @staticmethod - def clean_db(): - clear_db_dags() - clear_db_runs() - clear_db_xcom() - - @pytest.fixture(autouse=True) - def setup_attrs(self, configured_app) -> None: - """ - Setup For XCom endpoint TC - """ - self.app = configured_app - self.client = self.app.test_client() # type:ignore - # clear existing xcoms - self.clean_db() - - def teardown_method(self) -> None: - """ - Clear Hanging XComs - """ - self.clean_db() - - -class TestGetXComEntries(TestXComEndpoint): - def test_should_respond_200_with_tilde_and_granular_dag_access(self): - dag_id_1 = "test-dag-id-1" - task_id_1 = "test-task-id-1" - logical_date = "2005-04-02T00:00:00+00:00" - logical_date_parsed = timezone.parse(logical_date) - run_after = "2005-04-02T00:00:00+00:00" - run_after_parsed = timezone.parse(run_after) - dag_run_id_1 = DagRun.generate_run_id( - run_type=DagRunType.MANUAL, - logical_date=logical_date_parsed, - run_after=run_after_parsed, - ) - self._create_xcom_entries(dag_id_1, dag_run_id_1, logical_date_parsed, task_id_1) - - dag_id_2 = "test-dag-id-2" - task_id_2 = "test-task-id-2" - run_id_2 = DagRun.generate_run_id( - run_type=DagRunType.MANUAL, - logical_date=logical_date_parsed, - run_after=run_after_parsed, - ) - self._create_xcom_entries(dag_id_2, run_id_2, logical_date_parsed, task_id_2) - self._create_invalid_xcom_entries(logical_date_parsed) - response = self.client.get( - "/api/v1/dags/~/dagRuns/~/taskInstances/~/xcomEntries", - environ_overrides={"REMOTE_USER": "test_granular_permissions"}, - ) - - assert response.status_code == 200 - response_data = response.json - for xcom_entry in response_data["xcom_entries"]: - xcom_entry["timestamp"] = "TIMESTAMP" - date_key = "logical_date" if AIRFLOW_V_3_0_PLUS else "execution_date" - _compare_xcom_collections( - response_data, - { - "xcom_entries": [ - { - "dag_id": dag_id_1, - date_key: logical_date, - "key": "test-xcom-key-1", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - { - "dag_id": dag_id_1, - date_key: logical_date, - "key": "test-xcom-key-2", - "task_id": task_id_1, - "timestamp": "TIMESTAMP", - "map_index": -1, - }, - ], - "total_entries": 2, - }, - ) - - def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti=False): - with create_session() as session: - dag = DagModel(dag_id=dag_id) - session.add(dag) - if AIRFLOW_V_3_0_PLUS: - dagrun = DagRun( - dag_id=dag_id, - run_id=run_id, - logical_date=logical_date, - data_interval=(logical_date, logical_date), - run_after=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - else: - dagrun = DagRun( - dag_id=dag_id, - run_id=run_id, - execution_date=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - session.add(dagrun) - if mapped_ti: - for i in [0, 1]: - ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id, map_index=i) - ti.dag_id = dag_id - session.add(ti) - else: - ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id) - ti.dag_id = dag_id - session.add(ti) - - for i in [1, 2]: - if mapped_ti: - key = "test-xcom-key" - map_index = i - 1 - else: - key = f"test-xcom-key-{i}" - map_index = -1 - - XCom.set( - key=key, value="TEST", run_id=run_id, task_id=task_id, dag_id=dag_id, map_index=map_index - ) - - def _create_invalid_xcom_entries(self, logical_date): - """ - Invalid XCom entries to test join query - """ - with create_session() as session: - dag = DagModel(dag_id="invalid_dag") - session.add(dag) - if AIRFLOW_V_3_0_PLUS: - dagrun = DagRun( - dag_id="invalid_dag", - run_id="invalid_run_id", - logical_date=logical_date + timedelta(days=1), - run_after=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - else: - dagrun = DagRun( - dag_id="invalid_dag", - run_id="invalid_run_id", - execution_date=logical_date + timedelta(days=1), - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - session.add(dagrun) - if AIRFLOW_V_3_0_PLUS: - dagrun1 = DagRun( - dag_id="invalid_dag", - run_id="not_this_run_id", - logical_date=logical_date, - run_after=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - else: - dagrun1 = DagRun( - dag_id="invalid_dag", - run_id="not_this_run_id", - execution_date=logical_date, - start_date=logical_date, - run_type=DagRunType.MANUAL, - ) - session.add(dagrun1) - ti = TaskInstance(EmptyOperator(task_id="invalid_task"), run_id="not_this_run_id") - ti.dag_id = "invalid_dag" - session.add(ti) - for i in [1, 2]: - XCom.set( - key=f"invalid-xcom-key-{i}", - value="TEST", - run_id="not_this_run_id", - task_id="invalid_task", - dag_id="invalid_dag", - ) diff --git a/providers/fab/tests/unit/fab/auth_manager/conftest.py b/providers/fab/tests/unit/fab/auth_manager/conftest.py index 92d3dce0dc03d..49023851aa176 100644 --- a/providers/fab/tests/unit/fab/auth_manager/conftest.py +++ b/providers/fab/tests/unit/fab/auth_manager/conftest.py @@ -23,7 +23,7 @@ import pytest -from airflow.www import app +from airflow.providers.fab.www import app from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import parse_and_sync_to_db @@ -56,7 +56,7 @@ def factory(): ): "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager", } ): - _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + _app = app.create_app(enable_plugins=False) _app.config["AUTH_ROLE_PUBLIC"] = None return _app diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py index cbe4e0e5ac47c..0862d836deb41 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py @@ -27,6 +27,7 @@ from flask_appbuilder.menu import Menu from airflow.exceptions import AirflowConfigException, AirflowException +from airflow.providers.fab.www.extensions.init_appbuilder import init_appbuilder from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user try: @@ -61,7 +62,6 @@ RESOURCE_VARIABLE, RESOURCE_WEBSITE, ) -from airflow.www.extensions.init_appbuilder import init_appbuilder if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod @@ -94,7 +94,7 @@ def flask_app(): @pytest.fixture def auth_manager_with_appbuilder(flask_app): - appbuilder = init_appbuilder(flask_app) + appbuilder = init_appbuilder(flask_app, enable_plugins=False) auth_manager = FabAuthManager() auth_manager.appbuilder = appbuilder return auth_manager diff --git a/providers/fab/tests/unit/fab/auth_manager/test_security.py b/providers/fab/tests/unit/fab/auth_manager/test_security.py index daa7814a8873e..4ceca5cb3ac23 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_security.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_security.py @@ -19,9 +19,7 @@ import contextlib import datetime -import json import logging -import os from typing import TYPE_CHECKING from unittest import mock from unittest.mock import patch @@ -32,10 +30,10 @@ from flask_appbuilder.views import BaseView, ModelView from sqlalchemy import Column, Date, Float, Integer, String -from airflow.configuration import initialize_config from airflow.exceptions import AirflowException from airflow.models import DagModel from airflow.models.dag import DAG +from airflow.providers.fab.www.utils import CustomSQLAInterface from tests_common.test_utils.compat import ignore_provider_compatibility_error @@ -45,11 +43,9 @@ from airflow.providers.fab.auth_manager.models.anonymous_user import AnonymousUser from airflow.api_fastapi.app import get_auth_manager +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions from airflow.providers.fab.www.security.permissions import ACTION_CAN_READ -from airflow.www import app as application -from airflow.www.auth import get_access_denied_message -from airflow.www.utils import CustomSQLAInterface from unit.fab.auth_manager.api_endpoints.api_connexion_utils import ( create_user, create_user_scope, @@ -180,7 +176,7 @@ def clear_db_before_test(): @pytest.fixture(scope="module") def app(): - _app = application.create_app(testing=True) + _app = application.create_app(enable_plugins=False) _app.config["WTF_CSRF_ENABLED"] = False return _app @@ -1113,81 +1109,3 @@ def test_users_can_be_found(app, security_manager, session, caplog): assert len(users) == 1 delete_user(app, "Test") assert "Error adding new user to database" in caplog.text - - -def test_default_access_denied_message(): - initialize_config() - assert get_access_denied_message() == "Access is Denied" - - -def test_custom_access_denied_message(): - with mock.patch.dict( - os.environ, - {"AIRFLOW__WEBSERVER__ACCESS_DENIED_MESSAGE": "My custom access denied message"}, - clear=True, - ): - initialize_config() - assert get_access_denied_message() == "My custom access denied message" - - -@pytest.mark.db_test -class TestHasAccessDagDecorator: - @pytest.mark.parametrize( - "dag_id_args, dag_id_kwargs, dag_id_form, dag_id_json, fail", - [ - ("a", None, None, None, False), - (None, "b", None, None, False), - (None, None, "c", None, False), - (None, None, None, "d", False), - ("a", "a", None, None, False), - ("a", "a", "a", None, False), - ("a", "a", "a", "a", False), - (None, "a", "a", "a", False), - (None, None, "a", "a", False), - ("a", None, None, "a", False), - ("a", None, "a", None, False), - ("a", None, "c", None, True), - (None, "b", "c", None, True), - (None, None, "c", "d", True), - ("a", "b", "c", "d", True), - ], - ) - def test_dag_id_consistency( - self, - app, - dag_id_args: str | None, - dag_id_kwargs: str | None, - dag_id_form: str | None, - dag_id_json: str | None, - fail: bool, - ): - with app.test_request_context() as mock_context: - from airflow.www.auth import has_access_dag - - mock_context.request.args = {"dag_id": dag_id_args} if dag_id_args else {} - kwargs = {"dag_id": dag_id_kwargs} if dag_id_kwargs else {} - mock_context.request.form = {"dag_id": dag_id_form} if dag_id_form else {} - if dag_id_json: - mock_context.request._cached_data = json.dumps({"dag_id": dag_id_json}) - mock_context.request._parsed_content_type = ["application/json"] - - with create_user_scope( - app, - username="test-user", - role_name="limited-role", - permissions=[(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)], - ) as user: - with patch( - "airflow.providers.fab.auth_manager.fab_auth_manager.FabAuthManager.get_user" - ) as mock_get_user: - mock_get_user.return_value = user - - @has_access_dag("GET") - def test_func(**kwargs): - return True - - result = test_func(**kwargs) - if fail: - assert result[1] == 403 - else: - assert result is True diff --git a/providers/fab/tests/unit/fab/auth_manager/views/test_permissions.py b/providers/fab/tests/unit/fab/auth_manager/views/test_permissions.py index e87cb3f38c981..d1094650f8224 100644 --- a/providers/fab/tests/unit/fab/auth_manager/views/test_permissions.py +++ b/providers/fab/tests/unit/fab/auth_manager/views/test_permissions.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from unit.fab.auth_manager.views import _assert_dataset_deprecation_warning @@ -29,7 +29,7 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_app(enable_plugins=False) @pytest.fixture(scope="module") diff --git a/providers/fab/tests/unit/fab/auth_manager/views/test_roles_list.py b/providers/fab/tests/unit/fab/auth_manager/views/test_roles_list.py index 282e7d897eb8f..3ce950788c36a 100644 --- a/providers/fab/tests/unit/fab/auth_manager/views/test_roles_list.py +++ b/providers/fab/tests/unit/fab/auth_manager/views/test_roles_list.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from unit.fab.auth_manager.views import _assert_dataset_deprecation_warning @@ -29,7 +29,7 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_app(enable_plugins=False) @pytest.fixture(scope="module") diff --git a/providers/fab/tests/unit/fab/auth_manager/views/test_user.py b/providers/fab/tests/unit/fab/auth_manager/views/test_user.py index 3b34e1eb25214..be9dfa54fe7ea 100644 --- a/providers/fab/tests/unit/fab/auth_manager/views/test_user.py +++ b/providers/fab/tests/unit/fab/auth_manager/views/test_user.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from unit.fab.auth_manager.views import _assert_dataset_deprecation_warning @@ -29,7 +29,7 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_app(enable_plugins=False) @pytest.fixture(scope="module") diff --git a/providers/fab/tests/unit/fab/auth_manager/views/test_user_edit.py b/providers/fab/tests/unit/fab/auth_manager/views/test_user_edit.py index 0e6a54907eeb0..734efed215b17 100644 --- a/providers/fab/tests/unit/fab/auth_manager/views/test_user_edit.py +++ b/providers/fab/tests/unit/fab/auth_manager/views/test_user_edit.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from unit.fab.auth_manager.views import _assert_dataset_deprecation_warning @@ -29,7 +29,7 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_app(enable_plugins=False) @pytest.fixture(scope="module") diff --git a/providers/fab/tests/unit/fab/auth_manager/views/test_user_stats.py b/providers/fab/tests/unit/fab/auth_manager/views/test_user_stats.py index 7b0f8e04bad0b..51acb4147e370 100644 --- a/providers/fab/tests/unit/fab/auth_manager/views/test_user_stats.py +++ b/providers/fab/tests/unit/fab/auth_manager/views/test_user_stats.py @@ -19,8 +19,8 @@ import pytest +from airflow.providers.fab.www import app as application from airflow.providers.fab.www.security import permissions -from airflow.www import app as application from unit.fab.auth_manager.api_endpoints.api_connexion_utils import create_user, delete_user from unit.fab.auth_manager.views import _assert_dataset_deprecation_warning @@ -29,7 +29,7 @@ @pytest.fixture(scope="module") def fab_app(): - return application.create_app(testing=True) + return application.create_app(enable_plugins=False) @pytest.fixture(scope="module") diff --git a/providers/fab/tests/unit/fab/www/views/conftest.py b/providers/fab/tests/unit/fab/www/views/conftest.py index c064b8c9696b0..15f1d6261b65b 100644 --- a/providers/fab/tests/unit/fab/www/views/conftest.py +++ b/providers/fab/tests/unit/fab/www/views/conftest.py @@ -24,7 +24,7 @@ from airflow import settings from airflow.models import DagBag -from airflow.www.app import create_app +from airflow.providers.fab.www.app import create_app from unit.fab.auth_manager.api_endpoints.api_connexion_utils import delete_user from tests_common.test_utils.config import conf_vars @@ -63,7 +63,7 @@ def app(examples_dag_bag): ) def factory(): with conf_vars({("fab", "auth_rate_limited"): "False"}): - return create_app(testing=True) + return create_app(enable_plugins=False) app = factory() app.config["WTF_CSRF_ENABLED"] = False diff --git a/providers/fab/tests/unit/fab/www/views/test_views_acl.py b/providers/fab/tests/unit/fab/www/views/test_views_acl.py deleted file mode 100644 index dbc97b8a3799a..0000000000000 --- a/providers/fab/tests/unit/fab/www/views/test_views_acl.py +++ /dev/null @@ -1,1576 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -import datetime -import json -import urllib.parse - -import pytest - -from airflow import DAG, settings -from airflow.models import DagBag, DagModel, DagRun, TaskInstance, Variable -from airflow.models.errors import ParseImportError -from airflow.security import permissions -from airflow.utils import timezone -from airflow.utils.session import create_session -from airflow.utils.state import State -from airflow.utils.types import DagRunTriggeredByType, DagRunType -from airflow.www.views import FILTER_STATUS_COOKIE, DagRunModelView -from unit.fab.auth_manager.api_endpoints.api_connexion_utils import ( - create_test_client, - create_user, - create_user_scope, - delete_roles, - delete_user, -) - -from tests_common.test_utils.db import clear_db_runs -from tests_common.test_utils.permissions import _resource_name -from tests_common.test_utils.www import ( - capture_templates, # noqa: F401 - check_content_in_response, - check_content_not_in_response, - client_with_login, -) - -pytestmark = pytest.mark.db_test - -NEXT_YEAR = datetime.datetime.now().year + 1 -DEFAULT_DATE = timezone.datetime(NEXT_YEAR, 6, 1) -DEFAULT_RUN_ID = "TEST_RUN_ID" -USER_DATA = { - "dag_tester": ( - "dag_acl_tester", - { - "first_name": "dag_test", - "last_name": "dag_test", - "email": "dag_test@fab.org", - "password": "dag_test", - }, - ), - "dag_faker": ( # User without permission. - "dag_acl_faker", - { - "first_name": "dag_faker", - "last_name": "dag_faker", - "email": "dag_fake@fab.org", - "password": "dag_faker", - }, - ), - "dag_read_only": ( # User with only read permission. - "dag_acl_read_only", - { - "first_name": "dag_read_only", - "last_name": "dag_read_only", - "email": "dag_read_only@fab.org", - "password": "dag_read_only", - }, - ), - "all_dag_user": ( # User has all dag access. - "all_dag_role", - { - "first_name": "all_dag_user", - "last_name": "all_dag_user", - "email": "all_dag_user@fab.org", - "password": "all_dag_user", - }, - ), -} - - -def _get_appbuilder_pk_string(model_view_cls, instance) -> str: - """Utility to get Flask-Appbuilder's string format "pk" for an object. - - Used to generate requests to FAB action views without *too* much difficulty. - The implementation relies on FAB internals, but unfortunately I don't see - a better way around it. - - Example usage:: - - from airflow.www.views import TaskInstanceModelView - - ti = session.Query(TaskInstance).filter(...).one() - pk = _get_appbuilder_pk_string(TaskInstanceModelView, ti) - client.post("...", data={"action": "...", "rowid": pk}) - """ - pk_value = model_view_cls.datamodel.get_pk_value(instance) - return model_view_cls._serialize_pk_if_composite(model_view_cls, pk_value) - - -@pytest.fixture(scope="module") -def acl_app(app): - security_manager = app.appbuilder.sm - for username, (role_name, kwargs) in USER_DATA.items(): - if not security_manager.find_user(username=username): - role = security_manager.add_role(role_name) - security_manager.add_user( - role=role, - username=username, - **kwargs, - ) - - role_permissions = { - "dag_acl_tester": [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_EDIT, "DAG:example_bash_operator"), - (permissions.ACTION_CAN_READ, "DAG:example_bash_operator"), - ], - "all_dag_role": [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - "User": [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - "dag_acl_read_only": [ - (permissions.ACTION_CAN_READ, "DAG:example_bash_operator"), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - "dag_acl_faker": [(permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE)], - } - - for _role, _permissions in role_permissions.items(): - role = security_manager.find_role(_role) - for _action, _perm in _permissions: - perm = security_manager.get_permission(_action, _perm) - security_manager.add_permission_to_role(role, perm) - - yield app - - for username in USER_DATA: - user = security_manager.find_user(username=username) - if user: - security_manager.del_register_user(user) - - -@pytest.fixture(scope="module") -def _reset_dagruns(): - """Clean up stray garbage from other tests.""" - clear_db_runs() - - -@pytest.fixture(autouse=True) -def _init_dagruns(acl_app, _reset_dagruns): - acl_app.dag_bag.get_dag("example_bash_operator").create_dagrun( - run_id=DEFAULT_RUN_ID, - run_type=DagRunType.SCHEDULED, - logical_date=DEFAULT_DATE, - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - start_date=timezone.utcnow(), - state=State.RUNNING, - run_after=DEFAULT_DATE, - triggered_by=DagRunTriggeredByType.TEST, - ) - acl_app.dag_bag.get_dag("example_python_operator").create_dagrun( - run_id=DEFAULT_RUN_ID, - run_type=DagRunType.SCHEDULED, - logical_date=DEFAULT_DATE, - start_date=timezone.utcnow(), - data_interval=(DEFAULT_DATE, DEFAULT_DATE), - state=State.RUNNING, - run_after=DEFAULT_DATE, - triggered_by=DagRunTriggeredByType.TEST, - ) - yield - clear_db_runs() - - -@pytest.fixture -def dag_test_client(acl_app): - return client_with_login(acl_app, username="dag_test", password="dag_test") - - -@pytest.fixture -def dag_faker_client(acl_app): - return client_with_login(acl_app, username="dag_faker", password="dag_faker") - - -@pytest.fixture -def all_dag_user_client(acl_app): - return client_with_login( - acl_app, - username="all_dag_user", - password="all_dag_user", - ) - - -@pytest.fixture(scope="module") -def user_edit_one_dag(acl_app): - with create_user_scope( - acl_app, - username="user_edit_one_dag", - role_name="role_edit_one_dag", - permissions=[ - (permissions.ACTION_CAN_READ, "DAG:example_bash_operator"), - (permissions.ACTION_CAN_EDIT, "DAG:example_bash_operator"), - ], - ) as user: - yield user - - -@pytest.mark.usefixtures("user_edit_one_dag") -def test_permission_exist(acl_app): - perms_views = acl_app.appbuilder.sm.get_resource_permissions( - acl_app.appbuilder.sm.get_resource("DAG:example_bash_operator"), - ) - assert len(perms_views) == 3 - - perms = {str(perm) for perm in perms_views} - assert "can read on DAG:example_bash_operator" in perms - assert "can edit on DAG:example_bash_operator" in perms - assert "can delete on DAG:example_bash_operator" in perms - - -@pytest.mark.usefixtures("user_edit_one_dag") -def test_role_permission_associate(acl_app): - test_role = acl_app.appbuilder.sm.find_role("role_edit_one_dag") - perms = {str(perm) for perm in test_role.permissions} - assert "can edit on DAG:example_bash_operator" in perms - assert "can read on DAG:example_bash_operator" in perms - - -@pytest.fixture(scope="module") -def user_all_dags(acl_app): - with create_user_scope( - acl_app, - username="user_all_dags", - role_name="role_all_dags", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_all_dags(acl_app, user_all_dags): - return client_with_login( - acl_app, - username="user_all_dags", - password="user_all_dags", - ) - - -@pytest.fixture -def client_single_dag(app, user_single_dag): - """Client for User that can only access the first DAG from TEST_FILTER_DAG_IDS""" - return client_with_login( - app, - username="user_single_dag", - password="user_single_dag", - ) - - -@pytest.fixture(scope="module") -def client_dr_without_dag_run_create(app): - create_user( - app, - username="all_dr_permissions_except_dag_run_create", - role_name="all_dr_permissions_except_dag_run_create", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_RUN), - ], - ) - - yield client_with_login( - app, - username="all_dr_permissions_except_dag_run_create", - password="all_dr_permissions_except_dag_run_create", - ) - - delete_user(app, username="all_dr_permissions_except_dag_run_create") # type: ignore - delete_roles(app) - - -@pytest.fixture(scope="module") -def client_dr_without_dag_edit(app): - create_user( - app, - username="all_dr_permissions_except_dag_edit", - role_name="all_dr_permissions_except_dag_edit", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_DAG_RUN), - ], - ) - - yield client_with_login( - app, - username="all_dr_permissions_except_dag_edit", - password="all_dr_permissions_except_dag_edit", - ) - - delete_user(app, username="all_dr_permissions_except_dag_edit") # type: ignore - delete_roles(app) - - -@pytest.fixture(scope="module") -def user_no_importerror(app): - """Create User that cannot access Import Errors""" - return create_user( - app, - username="user_no_importerrors", - role_name="role_no_importerrors", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - ], - ) - - -@pytest.fixture -def client_no_importerror(app, user_no_importerror): - """Client for User that cannot access Import Errors""" - return client_with_login( - app, - username="user_no_importerrors", - password="user_no_importerrors", - ) - - -def test_index_for_all_dag_user(client_all_dags): - # The all dag user can access/view all dags. - resp = client_all_dags.get("/", follow_redirects=True) - check_content_in_response("example_python_operator", resp) - check_content_in_response("example_bash_operator", resp) - - -def test_index_failure(dag_test_client): - # This user can only access/view example_bash_operator dag. - resp = dag_test_client.get("/", follow_redirects=True) - check_content_not_in_response("example_python_operator", resp) - - -def test_dag_autocomplete_success(client_all_dags): - resp = client_all_dags.get( - "dagmodel/autocomplete?query=flow", - follow_redirects=False, - ) - expected = [ - {"name": "airflow", "type": "owner", "dag_display_name": None}, - { - "dag_display_name": None, - "name": "asset_alias_example_alias_consumer_with_no_taskflow", - "type": "dag", - }, - { - "dag_display_name": None, - "name": "asset_alias_example_alias_producer_with_no_taskflow", - "type": "dag", - }, - { - "dag_display_name": None, - "name": "asset_s3_bucket_consumer_with_no_taskflow", - "type": "dag", - }, - { - "dag_display_name": None, - "name": "asset_s3_bucket_producer_with_no_taskflow", - "type": "dag", - }, - { - "name": "example_dynamic_task_mapping_with_no_taskflow_operators", - "type": "dag", - "dag_display_name": None, - }, - {"name": "example_setup_teardown_taskflow", "type": "dag", "dag_display_name": None}, - {"name": "tutorial_taskflow_api", "type": "dag", "dag_display_name": None}, - {"name": "tutorial_taskflow_api_virtualenv", "type": "dag", "dag_display_name": None}, - {"name": "tutorial_taskflow_templates", "type": "dag", "dag_display_name": None}, - ] - - assert resp.json == expected - - -@pytest.mark.parametrize( - "query, expected", - [ - (None, []), - ("", []), - ("no-found", []), - ], - ids=["none", "empty", "not-found"], -) -def test_dag_autocomplete_empty(client_all_dags, query, expected): - url = "dagmodel/autocomplete" - if query is not None: - url = f"{url}?query={query}" - resp = client_all_dags.get(url, follow_redirects=False) - assert resp.json == expected - - -def test_dag_autocomplete_dag_display_name(client_all_dags): - url = "dagmodel/autocomplete?query=Sample" - resp = client_all_dags.get(url, follow_redirects=False) - assert resp.json == [ - {"name": "example_display_name", "type": "dag", "dag_display_name": "Sample DAG with Display Name"} - ] - - -@pytest.fixture -def _setup_paused_dag(): - """Pause a DAG so we can test filtering.""" - dag_to_pause = "example_branch_operator" - with create_session() as session: - session.query(DagModel).filter(DagModel.dag_id == dag_to_pause).update({"is_paused": True}) - yield - with create_session() as session: - session.query(DagModel).filter(DagModel.dag_id == dag_to_pause).update({"is_paused": False}) - - -@pytest.mark.parametrize( - "status, expected, unexpected", - [ - ("active", "example_branch_labels", "example_branch_operator"), - ("paused", "example_branch_operator", "example_branch_labels"), - ], -) -@pytest.mark.usefixtures("_setup_paused_dag") -def test_dag_autocomplete_status(client_all_dags, status, expected, unexpected): - with client_all_dags.session_transaction() as flask_session: - flask_session[FILTER_STATUS_COOKIE] = status - resp = client_all_dags.get( - "dagmodel/autocomplete?query=example_branch_", - follow_redirects=False, - ) - check_content_in_response(expected, resp) - check_content_not_in_response(unexpected, resp) - - -@pytest.fixture(scope="module") -def user_all_dags_dagruns(acl_app): - with create_user_scope( - acl_app, - username="user_all_dags_dagruns", - role_name="role_all_dags_dagruns", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_all_dags_dagruns(acl_app, user_all_dags_dagruns): - return client_with_login( - acl_app, - username="user_all_dags_dagruns", - password="user_all_dags_dagruns", - ) - - -def test_dag_stats_success(client_all_dags_dagruns): - resp = client_all_dags_dagruns.post("dag_stats", follow_redirects=True) - check_content_in_response("example_bash_operator", resp) - assert set(next(iter(resp.json.items()))[1][0].keys()) == {"state", "count"} - - -def test_task_stats_failure(dag_test_client): - resp = dag_test_client.post("task_stats", follow_redirects=True) - check_content_not_in_response("example_python_operator", resp) - - -def test_dag_stats_success_for_all_dag_user(client_all_dags_dagruns): - resp = client_all_dags_dagruns.post("dag_stats", follow_redirects=True) - check_content_in_response("example_python_operator", resp) - check_content_in_response("example_bash_operator", resp) - - -@pytest.fixture(scope="module") -def user_all_dags_dagruns_tis(acl_app): - with create_user_scope( - acl_app, - username="user_all_dags_dagruns_tis", - role_name="role_all_dags_dagruns_tis", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_all_dags_dagruns_tis(acl_app, user_all_dags_dagruns_tis): - return client_with_login( - acl_app, - username="user_all_dags_dagruns_tis", - password="user_all_dags_dagruns_tis", - ) - - -def test_task_stats_empty_success(client_all_dags_dagruns_tis): - resp = client_all_dags_dagruns_tis.post("task_stats", follow_redirects=True) - check_content_in_response("example_bash_operator", resp) - check_content_in_response("example_python_operator", resp) - - -@pytest.mark.parametrize( - "dags_to_run, unexpected_dag_ids", - [ - ( - ["example_python_operator"], - ["example_bash_operator", "example_xcom"], - ), - ( - ["example_python_operator", "example_bash_operator"], - ["example_xcom"], - ), - ], - ids=["single", "multi"], -) -def test_task_stats_success( - client_all_dags_dagruns_tis, - dags_to_run, - unexpected_dag_ids, -): - resp = client_all_dags_dagruns_tis.post( - "task_stats", data={"dag_ids": dags_to_run}, follow_redirects=True - ) - assert resp.status_code == 200 - for dag_id in unexpected_dag_ids: - check_content_not_in_response(dag_id, resp) - stats = json.loads(resp.data.decode()) - for dag_id in dags_to_run: - assert dag_id in stats - - -@pytest.fixture(scope="module") -def user_all_dags_codes(acl_app): - with create_user_scope( - acl_app, - username="user_all_dags_codes", - role_name="role_all_dags_codes", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_all_dags_codes(acl_app, user_all_dags_codes): - return client_with_login( - acl_app, - username="user_all_dags_codes", - password="user_all_dags_codes", - ) - - -def test_code_success(client_all_dags_codes): - url = "code?dag_id=example_bash_operator" - resp = client_all_dags_codes.get(url, follow_redirects=True) - check_content_in_response("example_bash_operator", resp) - - -def test_code_failure(dag_test_client): - url = "code?dag_id=example_bash_operator" - resp = dag_test_client.get(url, follow_redirects=True) - check_content_not_in_response("example_bash_operator", resp) - - -@pytest.mark.parametrize( - "dag_id", - ["example_bash_operator", "example_python_operator"], -) -def test_code_success_for_all_dag_user(client_all_dags_codes, dag_id): - url = f"code?dag_id={dag_id}" - resp = client_all_dags_codes.get(url, follow_redirects=True) - check_content_in_response(dag_id, resp) - - -@pytest.mark.parametrize( - "dag_id", - ["example_bash_operator", "example_python_operator"], -) -def test_dag_details_success_for_all_dag_user(client_all_dags_dagruns, dag_id): - url = f"dag_details?dag_id={dag_id}" - resp = client_all_dags_dagruns.get(url, follow_redirects=True) - check_content_in_response(dag_id, resp) - - -@pytest.fixture(scope="module") -def user_all_dags_tis(acl_app): - with create_user_scope( - acl_app, - username="user_all_dags_tis", - role_name="role_all_dags_tis", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_all_dags_tis(acl_app, user_all_dags_tis): - return client_with_login( - acl_app, - username="user_all_dags_tis", - password="user_all_dags_tis", - ) - - -@pytest.fixture(scope="module") -def user_all_dags_tis_xcom(acl_app): - with create_user_scope( - acl_app, - username="user_all_dags_tis_xcom", - role_name="role_all_dags_tis_xcom", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_all_dags_tis_xcom(acl_app, user_all_dags_tis_xcom): - return client_with_login( - acl_app, - username="user_all_dags_tis_xcom", - password="user_all_dags_tis_xcom", - ) - - -@pytest.fixture(scope="module") -def user_dags_tis_logs(acl_app): - with create_user_scope( - acl_app, - username="user_dags_tis_logs", - role_name="role_dags_tis_logs", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_dags_tis_logs(acl_app, user_dags_tis_logs): - return client_with_login( - acl_app, - username="user_dags_tis_logs", - password="user_dags_tis_logs", - ) - - -@pytest.fixture(scope="module") -def user_single_dag_edit(app): - """Create User that can edit DAG resource only a single DAG""" - return create_user( - app, - username="user_single_dag_edit", - role_name="role_single_dag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - ( - permissions.ACTION_CAN_EDIT, - _resource_name("filter_test_1", permissions.RESOURCE_DAG), - ), - ], - ) - - -@pytest.fixture -def client_single_dag_edit(app, user_single_dag_edit): - """Client for User that can only edit the first DAG from TEST_FILTER_DAG_IDS""" - return client_with_login( - app, - username="user_single_dag_edit", - password="user_single_dag_edit", - ) - - -RENDERED_TEMPLATES_URL = ( - f"rendered-templates?task_id=runme_0&dag_id=example_bash_operator&" - f"logical_date={urllib.parse.quote_plus(str(DEFAULT_DATE))}" -) -TASK_URL = ( - f"task?task_id=runme_0&dag_id=example_bash_operator&" - f"logical_date={urllib.parse.quote_plus(str(DEFAULT_DATE))}" -) -XCOM_URL = ( - f"xcom?task_id=runme_0&dag_id=example_bash_operator&" - f"logical_date={urllib.parse.quote_plus(str(DEFAULT_DATE))}" -) -DURATION_URL = "duration?days=30&dag_id=example_bash_operator" -TRIES_URL = "tries?days=30&dag_id=example_bash_operator" -LANDING_TIMES_URL = "landing_times?days=30&dag_id=example_bash_operator" -GANTT_URL = "gantt?dag_id=example_bash_operator" -GRID_DATA_URL = "object/grid_data?dag_id=example_bash_operator" -LOG_URL = ( - f"log?task_id=runme_0&dag_id=example_bash_operator&" - f"logical_date={urllib.parse.quote_plus(str(DEFAULT_DATE))}" -) - - -@pytest.mark.parametrize( - "client, url, expected_content", - [ - ("client_all_dags_tis", RENDERED_TEMPLATES_URL, "Rendered Template"), - ("all_dag_user_client", RENDERED_TEMPLATES_URL, "Rendered Template"), - ("client_all_dags_tis", TASK_URL, "Task Instance Details"), - ("client_all_dags_tis_xcom", XCOM_URL, "XCom"), - ("client_all_dags_tis", DURATION_URL, "example_bash_operator"), - ("client_all_dags_tis", TRIES_URL, "example_bash_operator"), - ("client_all_dags_tis", LANDING_TIMES_URL, "example_bash_operator"), - ("client_all_dags_tis", GANTT_URL, "example_bash_operator"), - ("client_dags_tis_logs", GRID_DATA_URL, "runme_1"), - ("viewer_client", GRID_DATA_URL, "runme_1"), - ("client_dags_tis_logs", LOG_URL, "Log by attempts"), - ("user_client", LOG_URL, "Log by attempts"), - ], - ids=[ - "rendered-templates", - "rendered-templates-all-dag-user", - "task", - "xcom", - "duration", - "tries", - "landing-times", - "gantt", - "grid-data-for-readonly-role", - "grid-data-for-viewer", - "log", - "log-for-user", - ], -) -def test_success(request, client, url, expected_content): - resp = request.getfixturevalue(client).get(url, follow_redirects=True) - check_content_in_response(expected_content, resp) - - -@pytest.mark.parametrize( - "url, unexpected_content", - [ - (RENDERED_TEMPLATES_URL, "Rendered Template"), - (TASK_URL, "Task Instance Details"), - (XCOM_URL, "XCom"), - (DURATION_URL, "example_bash_operator"), - (TRIES_URL, "example_bash_operator"), - (LANDING_TIMES_URL, "example_bash_operator"), - (GANTT_URL, "example_bash_operator"), - (LOG_URL, "Log by attempts"), - ], - ids=[ - "rendered-templates", - "task", - "xcom", - "duration", - "tries", - "landing-times", - "gantt", - "log", - ], -) -def test_failure(dag_faker_client, url, unexpected_content): - resp = dag_faker_client.get(url, follow_redirects=True) - check_content_not_in_response(unexpected_content, resp) - - -def test_blocked_success(client_all_dags_dagruns): - resp = client_all_dags_dagruns.post("blocked") - check_content_in_response("example_bash_operator", resp) - - -def test_blocked_success_for_all_dag_user(all_dag_user_client): - resp = all_dag_user_client.post("blocked") - check_content_in_response("example_bash_operator", resp) - check_content_in_response("example_python_operator", resp) - - -def test_blocked_viewer(viewer_client): - resp = viewer_client.post("blocked") - check_content_in_response("example_bash_operator", resp) - - -@pytest.mark.parametrize( - "dags_to_block, unexpected_dag_ids", - [ - ( - ["example_python_operator"], - ["example_bash_operator", "example_xcom"], - ), - ( - ["example_python_operator", "example_bash_operator"], - ["example_xcom"], - ), - ], - ids=["single", "multi"], -) -def test_blocked_success_when_selecting_dags( - admin_client, - dags_to_block, - unexpected_dag_ids, -): - resp = admin_client.post("blocked", data={"dag_ids": dags_to_block}) - assert resp.status_code == 200 - for dag_id in unexpected_dag_ids: - check_content_not_in_response(dag_id, resp) - blocked_dags = {blocked["dag_id"] for blocked in json.loads(resp.data.decode())} - for dag_id in dags_to_block: - assert dag_id in blocked_dags - - -@pytest.fixture(scope="module") -def user_all_dags_edit_tis(acl_app): - with create_user_scope( - acl_app, - username="user_all_dags_edit_tis", - role_name="role_all_dags_edit_tis", - permissions=[ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_all_dags_edit_tis(acl_app, user_all_dags_edit_tis): - return client_with_login( - acl_app, - username="user_all_dags_edit_tis", - password="user_all_dags_edit_tis", - ) - - -def test_failed_success(client_all_dags_edit_tis): - form = dict( - task_id="run_this_last", - dag_id="example_bash_operator", - dag_run_id=DEFAULT_RUN_ID, - upstream="false", - downstream="false", - future="false", - past="false", - ) - resp = client_all_dags_edit_tis.post("failed", data=form, follow_redirects=True) - check_content_in_response("Marked failed on 1 task instances", resp) - - -def test_paused_post_success(dag_test_client): - resp = dag_test_client.post("paused?dag_id=example_bash_operator&is_paused=false", follow_redirects=True) - check_content_in_response("OK", resp) - - -@pytest.fixture(scope="module") -def user_only_dags_tis(acl_app): - with create_user_scope( - acl_app, - username="user_only_dags_tis", - role_name="role_only_dags_tis", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_only_dags_tis(acl_app, user_only_dags_tis): - return client_with_login( - acl_app, - username="user_only_dags_tis", - password="user_only_dags_tis", - ) - - -def test_success_fail_for_read_only_task_instance_access(client_only_dags_tis): - form = dict( - task_id="run_this_last", - dag_id="example_bash_operator", - dag_run_id=DEFAULT_RUN_ID, - upstream="false", - downstream="false", - future="false", - past="false", - ) - resp = client_only_dags_tis.post("success", data=form) - check_content_not_in_response("Please confirm", resp, resp_code=302) - - -GET_LOGS_WITH_METADATA_URL = ( - f"get_logs_with_metadata?task_id=runme_0&dag_id=example_bash_operator&" - f"logical_date={urllib.parse.quote_plus(str(DEFAULT_DATE))}&" - f"try_number=1&metadata=null" -) - - -@pytest.mark.parametrize("client", ["client_dags_tis_logs", "user_client"]) -def test_get_logs_with_metadata_success(request, client): - resp = request.getfixturevalue(client).get( - GET_LOGS_WITH_METADATA_URL, - follow_redirects=True, - ) - check_content_in_response('"message":', resp) - check_content_in_response('"metadata":', resp) - - -def test_get_logs_with_metadata_failure(dag_faker_client): - resp = dag_faker_client.get( - GET_LOGS_WITH_METADATA_URL, - follow_redirects=True, - ) - check_content_not_in_response('"message":', resp) - check_content_not_in_response('"metadata":', resp) - - -@pytest.fixture(scope="module") -def user_no_roles(acl_app): - with create_user_scope(acl_app, username="no_roles_user", role_name="no_roles_user_role") as user: - user.roles = [] - yield user - - -@pytest.fixture -def client_no_roles(acl_app, user_no_roles): - return client_with_login( - acl_app, - username="no_roles_user", - password="no_roles_user", - ) - - -@pytest.fixture(scope="module") -def user_no_permissions(acl_app): - with create_user_scope( - acl_app, - username="no_permissions_user", - role_name="no_permissions_role", - ) as user: - yield user - - -@pytest.fixture -def client_no_permissions(acl_app, user_no_permissions): - return client_with_login( - acl_app, - username="no_permissions_user", - password="no_permissions_user", - ) - - -@pytest.fixture -def client_anonymous(acl_app): - return acl_app.test_client() - - -@pytest.fixture -def running_dag_run(session): - dag = DagBag().get_dag("example_bash_operator") - logical_date = timezone.datetime(2016, 1, 9) - dr = dag.create_dagrun( - state="running", - logical_date=logical_date, - data_interval=(logical_date, logical_date), - run_id="test_dag_runs_action", - run_type=DagRunType.MANUAL, - session=session, - run_after=logical_date, - triggered_by=DagRunTriggeredByType.TEST, - ) - session.add(dr) - tis = [ - TaskInstance(dag.get_task("runme_0"), run_id=dr.run_id, state="success"), - TaskInstance(dag.get_task("runme_1"), run_id=dr.run_id, state="failed"), - ] - session.bulk_save_objects(tis) - session.commit() - return dr - - -@pytest.fixture -def _working_dags(dag_maker): - for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS): - with dag_maker(dag_id=dag_id, fileloc=f"/{dag_id}.py", tags=[tag]): - # We need to enter+exit the dag maker context for it to create the dag - pass - - -@pytest.fixture -def _working_dags_with_read_perm(dag_maker): - for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS): - if dag_id == "filter_test_1": - access_control = {"role_single_dag": {"can_read"}} - else: - access_control = None - - with dag_maker(dag_id=dag_id, fileloc=f"/{dag_id}.py", tags=[tag], access_control=access_control): - pass - - -@pytest.fixture -def _working_dags_with_edit_perm(dag_maker): - for dag_id, tag in zip(TEST_FILTER_DAG_IDS, TEST_TAGS): - if dag_id == "filter_test_1": - access_control = {"role_single_dag": {"can_edit"}} - else: - access_control = None - - with dag_maker(dag_id=dag_id, fileloc=f"/{dag_id}.py", tags=[tag], access_control=access_control): - pass - - -@pytest.fixture -def _broken_dags(session): - from airflow.models.errors import ParseImportError - - for dag_id in TEST_FILTER_DAG_IDS: - session.add( - ParseImportError( - filename=f"/{dag_id}.py", bundle_name="dag_maker", stacktrace="Some Error\nTraceback:\n" - ) - ) - session.commit() - - -@pytest.fixture -def _broken_dags_after_working(dag_maker, session): - # First create and process a DAG file that works - path = "/all_in_one.py" - for dag_id in TEST_FILTER_DAG_IDS: - with dag_maker(dag_id=dag_id, fileloc=path, session=session): - pass - - # Then create an import error against that file - session.add( - ParseImportError(filename=path, bundle_name="dag_maker", stacktrace="Some Error\nTraceback:\n") - ) - session.commit() - - -@pytest.mark.parametrize( - "client, url, status_code, expected_content", - [ - ["client_no_roles", "/home", 403, "Your user has no roles and/or permissions!"], - ["client_no_permissions", "/home", 403, "Your user has no roles and/or permissions!"], - ["client_all_dags", "/home", 200, "DAGs - Airflow"], - ["client_anonymous", "/home", 200, "Sign In"], - ], -) -def test_no_roles_permissions(request, client, url, status_code, expected_content): - resp = request.getfixturevalue(client).get(url, follow_redirects=True) - check_content_in_response(expected_content, resp, status_code) - - -@pytest.fixture(scope="module") -def user_dag_level_access_with_ti_edit(acl_app): - with create_user_scope( - acl_app, - username="user_dag_level_access_with_ti_edit", - role_name="role_dag_level_access_with_ti_edit", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ( - permissions.ACTION_CAN_EDIT, - _resource_name("example_bash_operator", permissions.RESOURCE_DAG), - ), - ], - ) as user: - yield user - - -@pytest.fixture -def client_dag_level_access_with_ti_edit(acl_app, user_dag_level_access_with_ti_edit): - return client_with_login( - acl_app, - username="user_dag_level_access_with_ti_edit", - password="user_dag_level_access_with_ti_edit", - ) - - -def test_success_edit_ti_with_dag_level_access_only(client_dag_level_access_with_ti_edit): - form = dict( - task_id="run_this_last", - dag_id="example_bash_operator", - dag_run_id=DEFAULT_RUN_ID, - upstream="false", - downstream="false", - future="false", - past="false", - ) - resp = client_dag_level_access_with_ti_edit.post("/success", data=form, follow_redirects=True) - check_content_in_response("Marked success on 1 task instances", resp) - - -@pytest.fixture(scope="module") -def user_ti_edit_without_dag_level_access(acl_app): - with create_user_scope( - acl_app, - username="user_ti_edit_without_dag_level_access", - role_name="role_ti_edit_without_dag_level_access", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], - ) as user: - yield user - - -@pytest.fixture -def client_ti_edit_without_dag_level_access(acl_app, user_ti_edit_without_dag_level_access): - return client_with_login( - acl_app, - username="user_ti_edit_without_dag_level_access", - password="user_ti_edit_without_dag_level_access", - ) - - -@pytest.fixture(scope="module", autouse=True) -def _init_blank_dagrun(): - """Make sure there are no runs before we test anything. - - This really shouldn't be needed, but tests elsewhere leave the db dirty. - """ - with create_session() as session: - session.query(DagRun).delete() - session.query(TaskInstance).delete() - - -@pytest.fixture(autouse=True) -def _reset_dagrun(): - yield - with create_session() as session: - session.query(DagRun).delete() - session.query(TaskInstance).delete() - - -@pytest.fixture -def one_dag_perm_user_client(app): - username = "test_user_one_dag_perm" - dag_id = "example_bash_operator" - sm = app.appbuilder.sm - perm = f"{permissions.RESOURCE_DAG_PREFIX}{dag_id}" - - sm.create_permission(permissions.ACTION_CAN_READ, perm) - - create_user( - app, - username=username, - role_name="User with permission to access only one dag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_LOG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, perm), - ], - ) - - sm.find_user(username=username) - - yield client_with_login( - app, - username=username, - password=username, - ) - - delete_user(app, username=username) # type: ignore - delete_roles(app) - - -@pytest.fixture -def new_dag_to_delete(testing_dag_bundle): - dag = DAG( - "new_dag_to_delete", is_paused_upon_creation=True, schedule="0 * * * *", start_date=DEFAULT_DATE - ) - session = settings.Session() - DAG.bulk_write_to_db("testing", None, [dag], session=session) - return dag - - -@pytest.fixture -def per_dag_perm_user_client(app, new_dag_to_delete): - sm = app.appbuilder.sm - perm = f"{permissions.RESOURCE_DAG_PREFIX}{new_dag_to_delete.dag_id}" - - sm.create_permission(permissions.ACTION_CAN_DELETE, perm) - - create_user( - app, - username="test_user_per_dag_perms", - role_name="User with some perms", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, perm), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) - - sm.find_user(username="test_user_per_dag_perms") - - yield client_with_login( - app, - username="test_user_per_dag_perms", - password="test_user_per_dag_perms", - ) - - delete_user(app, username="test_user_per_dag_perms") # type: ignore - delete_roles(app) - - -@pytest.fixture(scope="module") -def client_ti_without_dag_edit(app): - create_user( - app, - username="all_ti_permissions_except_dag_edit", - role_name="all_ti_permissions_except_dag_edit", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_ACCESS_MENU, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], - ) - - yield client_with_login( - app, - username="all_ti_permissions_except_dag_edit", - password="all_ti_permissions_except_dag_edit", - ) - - delete_user(app, username="all_ti_permissions_except_dag_edit") # type: ignore - delete_roles(app) - - -def test_failure_edit_ti_without_dag_level_access(client_ti_edit_without_dag_level_access): - form = dict( - task_id="run_this_last", - dag_id="example_bash_operator", - dag_run_id=DEFAULT_RUN_ID, - upstream="false", - downstream="false", - future="false", - past="false", - ) - resp = client_ti_edit_without_dag_level_access.post("/success", data=form, follow_redirects=True) - check_content_not_in_response("Marked success on 1 task instances", resp) - - -def test_viewer_cant_trigger_dag(app): - """ - Test that the test_viewer user can't trigger DAGs. - """ - with create_test_client( - app, - user_name="test_user", - role_name="test_role", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], - ) as client: - url = "dags/example_bash_operator/trigger" - resp = client.get(url, follow_redirects=True) - response_data = resp.data.decode() - assert "Access is Denied" in response_data - - -def test_get_dagrun_can_view_dags_without_edit_perms(session, running_dag_run, client_dr_without_dag_edit): - """Test that a user without dag_edit but with dag_read permission can view the records""" - assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 2 - resp = client_dr_without_dag_edit.get("/dagrun/list/", follow_redirects=True) - check_content_in_response(running_dag_run.dag_id, resp) - - -def test_create_dagrun_permission_denied(session, client_dr_without_dag_run_create): - data = { - "state": "running", - "dag_id": "example_bash_operator", - "logical_date": "2018-07-06 05:06:03", - "run_id": "test_list_dagrun_includes_conf", - "conf": '{"include": "me"}', - } - - resp = client_dr_without_dag_run_create.post("/dagrun/add", data=data, follow_redirects=True) - check_content_in_response("Access is Denied", resp) - - -def test_delete_dagrun_permission_denied(session, running_dag_run, client_dr_without_dag_edit): - composite_key = _get_appbuilder_pk_string(DagRunModelView, running_dag_run) - - assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 2 - resp = client_dr_without_dag_edit.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) - check_content_in_response("Access is Denied", resp) - assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 2 - - -@pytest.mark.parametrize( - "action", - ["clear", "set_success", "set_failed", "set_running"], - ids=["clear", "success", "failed", "running"], -) -def test_set_dag_runs_action_permission_denied(client_dr_without_dag_edit, running_dag_run, action): - running_dag_id = running_dag_run.id - resp = client_dr_without_dag_edit.post( - "/dagrun/action_post", - data={"action": action, "rowid": [str(running_dag_id)]}, - follow_redirects=True, - ) - check_content_in_response("Access is Denied", resp) - - -def test_delete_dagrun(session, admin_client, running_dag_run): - composite_key = _get_appbuilder_pk_string(DagRunModelView, running_dag_run) - assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 2 - admin_client.post(f"/dagrun/delete/{composite_key}", follow_redirects=True) - assert session.query(DagRun).filter(DagRun.dag_id == running_dag_run.dag_id).count() == 1 - - -@pytest.mark.usefixtures("_broken_dags", "_working_dags") -def test_home_no_importerrors_perm(_broken_dags, client_no_importerror): - # Users without "can read on import errors" don't see any import errors - resp = client_no_importerror.get("home", follow_redirects=True) - check_content_not_in_response("Import Errors", resp) - - -TEST_FILTER_DAG_IDS = ["filter_test_1", "filter_test_2", "a_first_dag_id_asc", "filter.test"] -TEST_TAGS = ["example", "test", "team", "group"] - - -@pytest.fixture(scope="module") -def user_single_dag(app): - """Create User that can only access the first DAG from TEST_FILTER_DAG_IDS""" - return create_user( - app, - username="user_single_dag", - role_name="role_single_dag", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR), - ( - permissions.ACTION_CAN_READ, - _resource_name(TEST_FILTER_DAG_IDS[0], permissions.RESOURCE_DAG), - ), - ], - ) - - -@pytest.fixture -def testing_dag_bundle(): - from airflow.models.dagbundle import DagBundleModel - from airflow.utils.session import create_session - - with create_session() as session: - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: - testing = DagBundleModel(name="testing") - session.add(testing) - - -@pytest.fixture -def client_variable_reader(app, user_variable_reader): - """Client for User that can only access the first DAG from TEST_FILTER_DAG_IDS""" - return client_with_login( - app, - username="user_variable_reader", - password="user_variable_reader", - ) - - -VARIABLE = { - "key": "test_key", - "val": "text_val", - "description": "test_description", - "is_encrypted": True, -} - - -@pytest.fixture(autouse=True) -def _clear_variables(): - with create_session() as session: - session.query(Variable).delete() - - -@pytest.fixture(scope="module") -def user_variable_reader(app): - """Create User that can only read variables""" - return create_user( - app, - username="user_variable_reader", - role_name="role_variable_reader", - permissions=[ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_WEBSITE), - ], - ) - - -@pytest.fixture -def variable(session): - variable = Variable( - key=VARIABLE["key"], - val=VARIABLE["val"], - description=VARIABLE["description"], - ) - session.add(variable) - session.commit() - yield variable - session.query(Variable).filter(Variable.key == VARIABLE["key"]).delete() - session.commit() - - -@pytest.mark.parametrize( - "page", - [ - "home", - "home?status=all", - "home?status=active", - "home?status=paused", - "home?lastrun=running", - "home?lastrun=failed", - "home?lastrun=all_states", - ], -) -@pytest.mark.usefixtures("_working_dags_with_read_perm", "_broken_dags") -def test_home_importerrors_filtered_singledag_user(client_single_dag, page): - # Users that can only see certain DAGs get a filtered list of import errors - resp = client_single_dag.get(page, follow_redirects=True) - check_content_in_response("Import Errors", resp) - # They can see the first DAGs import error - check_content_in_response(f"/{TEST_FILTER_DAG_IDS[0]}.py", resp) - check_content_in_response("Traceback", resp) - # But not the rest - for dag_id in TEST_FILTER_DAG_IDS[1:]: - check_content_not_in_response(f"/{dag_id}.py", resp) - - -def test_home_importerrors_missing_read_on_all_dags_in_file(_broken_dags_after_working, client_single_dag): - # If a user doesn't have READ on all DAGs in a file, that files traceback is redacted - resp = client_single_dag.get("home", follow_redirects=True) - check_content_in_response("Import Errors", resp) - # They can see the DAG file has an import error - check_content_in_response("all_in_one.py", resp) - # And the traceback is redacted - check_content_not_in_response("Traceback", resp) - check_content_in_response("REDACTED", resp) - - -def test_home_dag_list_filtered_singledag_user(_working_dags_with_read_perm, client_single_dag): - # Users that can only see certain DAGs get a filtered list - resp = client_single_dag.get("home", follow_redirects=True) - # They can see the first DAG - check_content_in_response(f"dag_id={TEST_FILTER_DAG_IDS[0]}", resp) - # But not the rest - for dag_id in TEST_FILTER_DAG_IDS[1:]: - check_content_not_in_response(f"dag_id={dag_id}", resp) - - -def test_home_dag_edit_permissions( - capture_templates, # noqa: F811 - _working_dags_with_edit_perm, - client_single_dag_edit, -): - with capture_templates() as templates: - client_single_dag_edit.get("home", follow_redirects=True) - - dags = templates[0].local_context["dags"] - assert len(dags) > 0 - dag_edit_perm_tuple = [(dag.dag_id, dag.can_edit) for dag in dags] - assert ("filter_test_1", True) in dag_edit_perm_tuple - assert ("filter_test_2", False) in dag_edit_perm_tuple - - -def test_graph_view_without_dag_permission(app, one_dag_perm_user_client): - url = "/dags/example_bash_operator/graph" - resp = one_dag_perm_user_client.get(url, follow_redirects=True) - assert resp.status_code == 200 - assert ( - resp.request.url - == "http://localhost/dags/example_bash_operator/grid?tab=graph&dag_run_id=TEST_RUN_ID" - ) - check_content_in_response("example_bash_operator", resp) - - url = "/dags/example_xcom/graph" - resp = one_dag_perm_user_client.get(url, follow_redirects=True) - assert resp.status_code == 200 - assert resp.request.url == "http://localhost/home" - check_content_in_response("Access is Denied", resp) - - -def test_delete_just_dag_per_dag_permissions(new_dag_to_delete, per_dag_perm_user_client): - resp = per_dag_perm_user_client.post( - f"delete?dag_id={new_dag_to_delete.dag_id}&next=/home", follow_redirects=True - ) - check_content_in_response(f"Deleting DAG with id {new_dag_to_delete.dag_id}.", resp) - - -def test_import_variables_form_hidden(app, client_variable_reader): - resp = client_variable_reader.get("/variable/list/") - check_content_not_in_response("Import Variables", resp) - - -def test_action_muldelete_access_denied(session, client_variable_reader, variable): - var_id = variable.id - resp = client_variable_reader.post( - "/variable/action_post", - data={"action": "muldelete", "rowid": [var_id]}, - follow_redirects=True, - ) - check_content_in_response("Access is Denied", resp) diff --git a/providers/fab/tests/unit/fab/www/views/test_views_custom_user_views.py b/providers/fab/tests/unit/fab/www/views/test_views_custom_user_views.py index cdf02cd92fe18..2d363527646c6 100644 --- a/providers/fab/tests/unit/fab/www/views/test_views_custom_user_views.py +++ b/providers/fab/tests/unit/fab/www/views/test_views_custom_user_views.py @@ -25,8 +25,8 @@ from flask_appbuilder import SQLA from airflow import settings +from airflow.providers.fab.www import app as application from airflow.security import permissions -from airflow.www import app as application from unit.fab.auth_manager.api_endpoints.api_connexion_utils import ( create_user, delete_role, @@ -34,7 +34,6 @@ from tests_common.test_utils.www import ( check_content_in_response, - check_content_not_in_response, client_with_login, ) @@ -75,7 +74,7 @@ def setup_method(self): # an exception because app context teardown is removed and if even single request is run via app # it cannot be re-intialized again by passing it as constructor to SQLA # This makes the tests slightly slower (but they work with Flask 2.1 and 2.2 - self.app = application.create_app(testing=True) + self.app = application.create_app(enable_plugins=False) self.appbuilder = self.app.appbuilder self.app.config["WTF_CSRF_ENABLED"] = False self.security_manager = self.appbuilder.sm @@ -89,7 +88,7 @@ def delete_roles(self): delete_role(self.app, role_name) @pytest.mark.parametrize("url, _, expected_text", PERMISSIONS_TESTS_PARAMS) - def test_user_model_view_with_access(self, url, expected_text, _): + def test_user_model_view_without_access(self, url, expected_text, _): user_without_access = create_user( self.app, username="no_access", @@ -103,11 +102,12 @@ def test_user_model_view_with_access(self, url, expected_text, _): username="no_access", password="no_access", ) - response = client.get(url.replace("{user.id}", str(user_without_access.id)), follow_redirects=True) - check_content_not_in_response(expected_text, response) + response = client.get(url.replace("{user.id}", str(user_without_access.id)), follow_redirects=False) + assert response.status_code == 302 + assert response.location.startswith("/login/") @pytest.mark.parametrize("url, permission, expected_text", PERMISSIONS_TESTS_PARAMS) - def test_user_model_view_without_access(self, url, permission, expected_text): + def test_user_model_view_with_access(self, url, permission, expected_text): user_with_access = create_user( self.app, username="has_access", @@ -145,9 +145,9 @@ def test_user_model_view_without_delete_access(self): password="no_access", ) - response = client.post(f"/users/delete/{user_to_delete.id}", follow_redirects=True) - - check_content_not_in_response("Deleted Row", response) + response = client.post(f"/users/delete/{user_to_delete.id}", follow_redirects=False) + assert response.status_code == 302 + assert response.location.startswith("/login/") assert bool(self.security_manager.get_user_by_id(user_to_delete.id)) is True def test_user_model_view_with_delete_access(self): @@ -173,15 +173,10 @@ def test_user_model_view_with_delete_access(self): password="has_access", ) - response = client.post(f"/users/delete/{user_to_delete.id}", follow_redirects=True) - check_content_in_response("Deleted Row", response) - check_content_not_in_response(user_to_delete.username, response) + client.post(f"/users/delete/{user_to_delete.id}", follow_redirects=True) assert bool(self.security_manager.get_user_by_id(user_to_delete.id)) is False -# type: ignore[attr-defined] - - class TestResetUserSessions: @classmethod def setup_class(cls): @@ -192,7 +187,7 @@ def setup_method(self): # an exception because app context teardown is removed and if even single request is run via app # it cannot be re-intialized again by passing it as constructor to SQLA # This makes the tests slightly slower (but they work with Flask 2.1 and 2.2 - self.app = application.create_app(testing=True) + self.app = application.create_app(enable_plugins=False) self.appbuilder = self.app.appbuilder self.app.config["WTF_CSRF_ENABLED"] = False self.security_manager = self.appbuilder.sm