Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,20 +185,14 @@ def get_cli_commands() -> list[CLICommand]:
return commands

def get_fastapi_app(self) -> FastAPI | None:
flask_blueprint = self.get_api_endpoints()

if not flask_blueprint:
return None

flask_app = create_app(enable_plugins=False)
flask_app.register_blueprint(flask_blueprint)

app = FastAPI(
title="FAB auth manager API",
description=(
"This is FAB auth manager API. This API is only available if the auth manager used in "
"the Airflow environment is FAB auth manager. "
"This API provides endpoints to manager users and permissions managed by the FAB auth "
"This API provides endpoints to manage users and permissions managed by the FAB auth "
"manager."
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@
from airflow.providers.fab.www.security import permissions
from airflow.providers.fab.www.security_manager import AirflowSecurityManagerV2
from airflow.providers.fab.www.session import (
AirflowDatabaseSessionInterface,
AirflowDatabaseSessionInterface as FabAirflowDatabaseSessionInterface,
)
from airflow.www.session import AirflowDatabaseSessionInterface

if TYPE_CHECKING:
from airflow.providers.fab.www.security.permissions import RESOURCE_ASSET
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
from __future__ import annotations

from http import HTTPStatus
from typing import Any
from typing import TYPE_CHECKING, Any

from connexion import ProblemException
import werkzeug
from connexion import FlaskApi, ProblemException, problem

from airflow.utils.docs import get_docs_url

if TYPE_CHECKING:
import flask

doc_link = get_docs_url("stable-rest-api-ref.html")

EXCEPTIONS_LINK_MAP = {
Expand All @@ -36,6 +40,39 @@
}


def common_error_handler(exception: BaseException) -> flask.Response:
"""Use to capture connexion exceptions and add link to the type field."""
if isinstance(exception, ProblemException):
link = EXCEPTIONS_LINK_MAP.get(exception.status)
if link:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=link,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
response = problem(
status=exception.status,
title=exception.title,
detail=exception.detail,
type=exception.type,
instance=exception.instance,
headers=exception.headers,
ext=exception.ext,
)
else:
if not isinstance(exception, werkzeug.exceptions.HTTPException):
exception = werkzeug.exceptions.InternalServerError()

response = problem(title=exception.name, detail=exception.description, status=exception.code)

return FlaskApi.get_response(response)


class NotFound(ProblemException):
"""Raise when the object cannot be found."""

Expand Down
16 changes: 14 additions & 2 deletions providers/fab/src/airflow/providers/fab/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
from airflow.providers.fab.www.extensions.init_jinja_globals import init_jinja_globals
from airflow.providers.fab.www.extensions.init_manifest_files import configure_manifest_files
from airflow.providers.fab.www.extensions.init_security import init_api_auth, init_xframe_protection
from airflow.providers.fab.www.extensions.init_views import init_error_handlers, init_plugins
from airflow.providers.fab.www.extensions.init_session import init_airflow_session_interface
from airflow.providers.fab.www.extensions.init_views import (
init_api_auth_provider,
init_api_error_handlers,
init_error_handlers,
init_plugins,
)

app: Flask | None = None

Expand All @@ -58,6 +64,8 @@ def create_app(enable_plugins: bool):
if "SQLALCHEMY_ENGINE_OPTIONS" not in flask_app.config:
flask_app.config["SQLALCHEMY_ENGINE_OPTIONS"] = settings.prepare_engine_args()

csrf.init_app(flask_app)

db = SQLA()
db.session = settings.Session
db.init_app(flask_app)
Expand All @@ -68,11 +76,15 @@ def create_app(enable_plugins: bool):

with flask_app.app_context():
init_appbuilder(flask_app, enable_plugins=enable_plugins)
init_error_handlers(flask_app)
if enable_plugins:
init_plugins(flask_app)
init_error_handlers(flask_app)
else:
init_api_auth_provider(flask_app)
init_api_error_handlers(flask_app)
init_jinja_globals(flask_app, enable_plugins=enable_plugins)
init_xframe_protection(flask_app)
init_airflow_session_interface(flask_app)
return flask_app


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,15 @@ def init_app(self, app, session):
self._add_admin_views()
self._add_addon_views()
self._init_extension(app)
self._swap_url_filter()

def _swap_url_filter(self):
"""Use our url filtering util function so there is consistency between FAB and Airflow routes."""
from flask_appbuilder.security import views as fab_sec_views

from airflow.providers.fab.www.views import get_safe_url

fab_sec_views.get_safe_redirect = get_safe_url

def _init_extension(self, app):
app.appbuilder = self
Expand Down Expand Up @@ -518,6 +527,9 @@ def add_view_no_menu(self, baseview, endpoint=None, static_folder=None):
def get_url_for_index(self):
return url_for(f"{self.indexview.endpoint}.{self.indexview.default_view}")

def get_url_for_login_with(self, next_url: str | None = None) -> str:
return get_auth_manager().get_url_login(next_url=next_url)

def get_url_for_locale(self, lang):
return url_for(
f"{self.bm.locale_view.endpoint}.{self.bm.locale_view.default_view}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,33 @@

from connexion import Resolver
from connexion.decorators.validation import RequestBodyValidator
from connexion.exceptions import BadRequestProblem
from flask import jsonify
from starlette import status
from connexion.exceptions import BadRequestProblem, ProblemException
from flask import request

from airflow.providers.fab.www.api_connexion.exceptions import (
BadRequest,
NotFound,
PermissionDenied,
Unauthenticated,
)
from airflow.api_connexion.exceptions import common_error_handler
from airflow.api_fastapi.app import get_auth_manager

if TYPE_CHECKING:
from flask import Flask

log = logging.getLogger(__name__)


def init_appbuilder_views(app):
"""Initialize Web UI views."""
from airflow.models import import_all_models
from airflow.providers.fab.www import views

import_all_models()

appbuilder = app.appbuilder

# Remove the session from scoped_session registry to avoid
# reusing a session with a disconnected connection
appbuilder.session.remove()
appbuilder.add_view_no_menu(views.FabIndexView())


class _LazyResolution:
"""
OpenAPI endpoint that lazily resolves the function on first use.
Expand Down Expand Up @@ -121,28 +131,47 @@ def init_plugins(app):
app.register_blueprint(blue_print["blueprint"])


def init_error_handlers(app: Flask):
"""Add custom errors handlers."""
base_paths: list[str] = [] # contains the list of base paths that have api endpoints

def handle_bad_request(error):
response = {"error": "Bad request"}
return jsonify(response), status.HTTP_400_BAD_REQUEST

def handle_not_found(error):
response = {"error": "Not found"}
return jsonify(response), status.HTTP_404_NOT_FOUND
def init_api_error_handlers(app: Flask) -> None:
"""Add error handlers for 404 and 405 errors for existing API paths."""
from airflow.providers.fab.www import views

def handle_unauthenticated(error):
response = {"error": "User is not authenticated"}
return jsonify(response), status.HTTP_401_UNAUTHORIZED
@app.errorhandler(404)
def _handle_api_not_found(ex):
if any([request.path.startswith(p) for p in base_paths]):
# 404 errors are never handled on the blueprint level
# unless raised from a view func so actual 404 errors,
# i.e. "no route for it" defined, need to be handled
# here on the application level
return common_error_handler(ex)
else:
return views.not_found(ex)

@app.errorhandler(405)
def _handle_method_not_allowed(ex):
if any([request.path.startswith(p) for p in base_paths]):
return common_error_handler(ex)
else:
return views.method_not_allowed(ex)

app.register_error_handler(ProblemException, common_error_handler)


def init_error_handlers(app: Flask):
"""Add custom errors handlers."""
from airflow.providers.fab.www import views

def handle_denied(error):
response = {"error": "Access is denied"}
return jsonify(response), status.HTTP_403_FORBIDDEN
app.register_error_handler(500, views.show_traceback)
app.register_error_handler(404, views.not_found)

app.register_error_handler(404, handle_not_found)

app.register_error_handler(BadRequest, handle_bad_request)
app.register_error_handler(NotFound, handle_not_found)
app.register_error_handler(Unauthenticated, handle_unauthenticated)
app.register_error_handler(PermissionDenied, handle_denied)
def init_api_auth_provider(app):
"""Initialize the API offered by the auth manager."""
auth_mgr = get_auth_manager()
blueprint = auth_mgr.get_api_endpoints()
if blueprint:
base_paths.append(blueprint.url_prefix)
app.register_blueprint(blueprint)
app.extensions["csrf"].exempt(blueprint)
59 changes: 57 additions & 2 deletions providers/fab/src/airflow/providers/fab/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

import sys
import traceback
from urllib.parse import unquote, urljoin, urlsplit

from flask import (
g,
redirect,
render_template,
request,
url_for,
)
from flask_appbuilder import IndexView, expose

Expand All @@ -32,6 +35,26 @@
from airflow.utils.net import get_hostname
from airflow.version import version

# Following the release of https://github.com/python/cpython/issues/102153 in Python 3.9.17 on
# June 6, 2023, we are adding extra sanitization of the urls passed to get_safe_url method to make it works
# the same way regardless if the user uses latest Python patchlevel versions or not. This also follows
# a recommended solution by the Python core team.
#
# From: https://github.com/python/cpython/commit/d28bafa2d3e424b6fdcfd7ae7cde8e71d7177369
#
# We recommend that users of these APIs where the values may be used anywhere
# with security implications code defensively. Do some verification within your
# code before trusting a returned component part. Does that ``scheme`` make
# sense? Is that a sensible ``path``? Is there anything strange about that
# ``hostname``? etc.
#
# C0 control and space to be stripped per WHATWG spec.
# == "".join([chr(i) for i in range(0, 0x20 + 1)])
_WHATWG_C0_CONTROL_OR_SPACE = (
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c"
"\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f "
)


class FabIndexView(IndexView):
"""
Expand All @@ -41,8 +64,6 @@ class FabIndexView(IndexView):
authenticated. It is impossible to redirect the user directly to the Airflow 3 UI index page before
redirecting them to this page because FAB itself defines the logic redirection and does not allow external
redirect.

It is impossible to redirect the user before
"""

@expose("/")
Expand Down Expand Up @@ -75,3 +96,37 @@ def show_traceback(error):
),
500,
)


def not_found(error):
"""Show Not Found on screen for any error in the Webserver."""
return (
render_template(
"airflow/error.html",
hostname="",
status_code=404,
error_message="Page cannot be found.",
),
404,
)


def get_safe_url(url):
"""Given a user-supplied URL, ensure it points to our web server."""
if not url:
return url_for("FabIndexView.index")

# If the url contains semicolon, redirect it to homepage to avoid
# potential XSS. (Similar to https://github.com/python/cpython/pull/24297/files (bpo-42967))
if ";" in unquote(url):
return url_for("FabIndexView.index")

url = url.lstrip(_WHATWG_C0_CONTROL_OR_SPACE)

host_url = urlsplit(request.host_url)
redirect_url = urlsplit(urljoin(request.host_url, url))
if not (redirect_url.scheme in ("http", "https") and host_url.netloc == redirect_url.netloc):
return url_for("FabIndexView.index")

# This will ensure we only redirect to the right scheme/netloc
return redirect_url.geturl()
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from flask_appbuilder.const import AUTH_LDAP

from airflow.providers.fab.auth_manager.api.auth.backend.basic_auth import requires_authentication
from airflow.www import app as application
from airflow.providers.fab.www import app as application


@pytest.fixture
def app():
return application.create_app(testing=True)
return application.create_app(enable_plugins=False)


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
from flask import Response

from airflow.providers.fab.auth_manager.api.auth.backend.session import requires_authentication
from airflow.www import app as application
from airflow.providers.fab.www import app as application


@pytest.fixture
def app():
return application.create_app(testing=True)
return application.create_app(enable_plugins=False)


mock_call = Mock()
Expand Down
Loading
Loading