diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml index 39ec62fc4ec3a..6b0e789459a8a 100644 --- a/airflow-core/pyproject.toml +++ b/airflow-core/pyproject.toml @@ -80,10 +80,6 @@ dependencies = [ # 0.115.10 fastapi was a bad release that broke our API's and static checks. # Related fastapi issue here: https://github.com/fastapi/fastapi/discussions/13431 "fastapi[standard]>=0.115.0,!=0.115.10", - # We could get rid of flask and gunicorn if we replace serve_logs with a starlette + unicorn - "flask>=2.1.1", - # We could get rid of flask and gunicorn if we replace serve_logs with a starlette + unicorn - "gunicorn>=20.1.0", "httpx>=0.25.0", 'importlib_metadata>=6.5;python_version<"3.12"', 'importlib_metadata>=7.0;python_version>="3.12"', diff --git a/airflow-core/src/airflow/utils/serve_logs/__init__.py b/airflow-core/src/airflow/utils/serve_logs/__init__.py new file mode 100644 index 0000000000000..b2468c90cde6b --- /dev/null +++ b/airflow-core/src/airflow/utils/serve_logs/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.utils.serve_logs.core import serve_logs +from airflow.utils.serve_logs.log_server import create_app + +__all__ = ["serve_logs", "create_app"] diff --git a/airflow-core/src/airflow/utils/serve_logs/core.py b/airflow-core/src/airflow/utils/serve_logs/core.py new file mode 100644 index 0000000000000..c7d15f0d24009 --- /dev/null +++ b/airflow-core/src/airflow/utils/serve_logs/core.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Serve logs process.""" + +from __future__ import annotations + +import logging +import socket +import sys + +import uvicorn + +from airflow.configuration import conf + +logger = logging.getLogger(__name__) + + +def serve_logs(port=None): + """Serve logs generated by Worker.""" + # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 + os_type = sys.platform + if os_type == "darwin": + logger.debug("Mac OS detected, skipping setproctitle") + else: + from setproctitle import setproctitle + + setproctitle("airflow serve-logs") + + port = port or conf.getint("logging", "WORKER_LOG_SERVER_PORT") + + # If dual stack is available and IPV6_V6ONLY is not enabled on the socket + # then when IPV6 is bound to it will also bind to IPV4 automatically + if getattr(socket, "has_dualstack_ipv6", lambda: False)(): + host = "::" # ASGI uses `::` syntax for IPv6 binding instead of the `[::]` notation used in WSGI, while preserving the `[::]` format in logs + serve_log_uri = f"http://[::]:{port}" + else: + host = "0.0.0.0" + serve_log_uri = f"http://{host}:{port}" + + logger.info("Starting log server on %s", serve_log_uri) + + # Use uvicorn directly for ASGI applications + uvicorn.run("airflow.utils.serve_logs.log_server:app", host=host, port=port, workers=2, log_level="info") + # Note: if we want to use more than 1 workers, we **can't** use the instance of FastAPI directly + # This is way we split the instantiation of log server to a separate module + # + # https://github.com/encode/uvicorn/blob/374bb6764e8d7f34abab0746857db5e3d68ecfdd/docs/deployment/index.md?plain=1#L50-L63 + + +if __name__ == "__main__": + serve_logs() diff --git a/airflow-core/src/airflow/utils/serve_logs.py b/airflow-core/src/airflow/utils/serve_logs/log_server.py similarity index 60% rename from airflow-core/src/airflow/utils/serve_logs.py rename to airflow-core/src/airflow/utils/serve_logs/log_server.py index 899547346b65c..fa0338e5c9f6e 100644 --- a/airflow-core/src/airflow/utils/serve_logs.py +++ b/airflow-core/src/airflow/utils/serve_logs/log_server.py @@ -14,18 +14,16 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Serve logs process.""" +"""Log server written in FastAPI.""" from __future__ import annotations import logging import os -import socket -import sys -from collections import namedtuple +from typing import cast -import gunicorn.app.base -from flask import Flask, abort, request, send_from_directory +from fastapi import FastAPI, HTTPException, Request, status +from fastapi.staticfiles import StaticFiles from jwt.exceptions import ( ExpiredSignatureError, ImmatureSignatureError, @@ -33,7 +31,6 @@ InvalidIssuedAtError, InvalidSignatureError, ) -from werkzeug.exceptions import HTTPException from airflow.api_fastapi.auth.tokens import JWTValidator, get_signing_key from airflow.configuration import conf @@ -43,55 +40,36 @@ logger = logging.getLogger(__name__) -def create_app(): - flask_app = Flask(__name__, static_folder=None) - leeway = conf.getint("webserver", "log_request_clock_grace", fallback=30) - log_directory = os.path.expanduser(conf.get("logging", "BASE_LOG_FOLDER")) - log_config_class = conf.get("logging", "logging_config_class") - if log_config_class: - logger.info("Detected user-defined logging config. Attempting to load %s", log_config_class) - try: - logging_config = import_string(log_config_class) - try: - base_log_folder = logging_config["handlers"]["task"]["base_log_folder"] - except KeyError: - base_log_folder = None - if base_log_folder is not None: - log_directory = base_log_folder - logger.info( - "Successfully imported user-defined logging config. Flask App will serve log from %s", - log_directory, - ) - else: - logger.warning( - "User-defined logging config does not specify 'base_log_folder'. " - "Flask App will use default log directory %s", - base_log_folder, - ) - except Exception as e: - raise ImportError(f"Unable to load {log_config_class} due to error: {e}") - signer = JWTValidator( - issuer=None, - secret_key=get_signing_key("api", "secret_key"), - algorithm="HS512", - leeway=leeway, - audience="task-instance-logs", - ) +class JWTAuthStaticFiles(StaticFiles): + """StaticFiles with JWT authentication.""" - # Prevent direct access to the logs port - @flask_app.before_request - def validate_pre_signed_url(): + # reference from https://github.com/fastapi/fastapi/issues/858#issuecomment-876564020 + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + async def __call__(self, scope, receive, send) -> None: + request = Request(scope, receive) + await self.validate_jwt_token(request) + await super().__call__(scope, receive, send) + + async def validate_jwt_token(self, request: Request): + # we get the signer from the app state instead of creating a new instance for each request + signer = cast("JWTValidator", request.app.state.signer) try: auth = request.headers.get("Authorization") if auth is None: logger.warning("The Authorization header is missing: %s.", request.headers) - abort(403) - payload = signer.validated_claims(auth) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Authorization header missing" + ) + payload = await signer.avalidated_claims(auth) token_filename = payload.get("filename") - request_filename = request.view_args["filename"] + # Extract filename from url path + request_filename = request.url.path.lstrip("/log/") if token_filename is None: logger.warning("The payload does not contain 'filename' key: %s.", payload) - abort(403) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) if token_filename != request_filename: logger.warning( "The payload log_relative_path key is different than the one in token:" @@ -99,18 +77,18 @@ def validate_pre_signed_url(): request_filename, token_filename, ) - abort(403) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) except HTTPException: raise except InvalidAudienceError: logger.warning("Invalid audience for the request", exc_info=True) - abort(403) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) except InvalidSignatureError: logger.warning("The signature of the request was wrong", exc_info=True) - abort(403) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) except ImmatureSignatureError: logger.warning("The signature of the request was sent from the future", exc_info=True) - abort(403) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) except ExpiredSignatureError: logger.warning( "The signature of the request has expired. Make sure that all components " @@ -119,78 +97,64 @@ def validate_pre_signed_url(): get_docs_url("configurations-ref.html#secret-key"), exc_info=True, ) - abort(403) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) except InvalidIssuedAtError: logger.warning( - "The request was issues in the future. Make sure that all components " + "The request was issued in the future. Make sure that all components " "in your system have synchronized clocks. " "See more at %s", get_docs_url("configurations-ref.html#secret-key"), exc_info=True, ) - abort(403) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) except Exception: logger.warning("Unknown error", exc_info=True) - abort(403) - - @flask_app.route("/log/") - def serve_logs_view(filename): - return send_from_directory(log_directory, filename, mimetype="application/json", as_attachment=False) - - return flask_app - - -GunicornOption = namedtuple("GunicornOption", ["key", "value"]) - + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN) -class StandaloneGunicornApplication(gunicorn.app.base.BaseApplication): - """ - Standalone Gunicorn application/serve for usage with any WSGI-application. - Code inspired by an example from the Gunicorn documentation. - https://github.com/benoitc/gunicorn/blob/cf55d2cec277f220ebd605989ce78ad1bb553c46/examples/standalone_app.py - - For details, about standalone gunicorn application, see: - https://docs.gunicorn.org/en/stable/custom.html - """ - - def __init__(self, app, options=None): - self.options = options or [] - self.application = app - super().__init__() - - def load_config(self): - for option in self.options: - self.cfg.set(option.key.lower(), option.value) - - def load(self): - return self.application - - -def serve_logs(port=None): - """Serve logs generated by Worker.""" - # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 - os_type = sys.platform - if os_type == "darwin": - logger.debug("Mac OS detected, skipping setproctitle") - else: - from setproctitle import setproctitle - - setproctitle("airflow serve-logs") - wsgi_app = create_app() +def create_app(): + leeway = conf.getint("webserver", "log_request_clock_grace", fallback=30) + log_directory = os.path.expanduser(conf.get("logging", "BASE_LOG_FOLDER")) + log_config_class = conf.get("logging", "logging_config_class") + if log_config_class: + logger.info("Detected user-defined logging config. Attempting to load %s", log_config_class) + try: + logging_config = import_string(log_config_class) + try: + base_log_folder = logging_config["handlers"]["task"]["base_log_folder"] + except KeyError: + base_log_folder = None + if base_log_folder is not None: + log_directory = base_log_folder + logger.info( + "Successfully imported user-defined logging config. FastAPI App will serve log from %s", + log_directory, + ) + else: + logger.warning( + "User-defined logging config does not specify 'base_log_folder'. " + "FastAPI App will use default log directory %s", + base_log_folder, + ) + except Exception as e: + raise ImportError(f"Unable to load {log_config_class} due to error: {e}") - port = port or conf.getint("logging", "WORKER_LOG_SERVER_PORT") + fastapi_app = FastAPI() + fastapi_app.state.signer = JWTValidator( + issuer=None, + secret_key=get_signing_key("api", "secret_key"), + algorithm="HS512", + leeway=leeway, + audience="task-instance-logs", + ) - # If dual stack is available and IPV6_V6ONLY is not enabled on the socket - # then when IPV6 is bound to it will also bind to IPV4 automatically - if getattr(socket, "has_dualstack_ipv6", lambda: False)(): - bind_option = GunicornOption("bind", f"[::]:{port}") - else: - bind_option = GunicornOption("bind", f"0.0.0.0:{port}") + fastapi_app.mount( + "/log", + JWTAuthStaticFiles(directory=log_directory, html=False), + name="serve_logs", + ) - options = [bind_option, GunicornOption("workers", 2)] - StandaloneGunicornApplication(wsgi_app, options).run() + return fastapi_app -if __name__ == "__main__": - serve_logs() +app = create_app() diff --git a/airflow-core/tests/unit/utils/test_serve_logs.py b/airflow-core/tests/unit/utils/test_serve_logs.py index 98d19c93b44ff..fbbef2b200046 100644 --- a/airflow-core/tests/unit/utils/test_serve_logs.py +++ b/airflow-core/tests/unit/utils/test_serve_logs.py @@ -18,11 +18,11 @@ from datetime import timedelta from pathlib import Path -from typing import TYPE_CHECKING import jwt import pytest import time_machine +from fastapi.testclient import TestClient from airflow.api_fastapi.auth.tokens import JWTGenerator from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG @@ -31,9 +31,6 @@ from tests_common.test_utils.config import conf_vars -if TYPE_CHECKING: - from flask.testing import FlaskClient - LOG_DATA = "Airflow log data" * 20 @@ -46,8 +43,7 @@ def client_without_config(tmp_path): } ): app = create_app() - - yield app.test_client() + yield TestClient(app) @pytest.fixture @@ -62,8 +58,7 @@ def client_with_config(): } ): app = create_app() - - yield app.test_client() + yield TestClient(app) @pytest.fixture(params=["client_without_config", "client_with_config"]) @@ -107,20 +102,20 @@ def different_audience(secret_key): @pytest.mark.usefixtures("sample_log") class TestServeLogs: - def test_forbidden_no_auth(self, client: FlaskClient): + def test_forbidden_no_auth(self, client: TestClient): assert client.get("/log/sample.log").status_code == 403 - def test_should_serve_file(self, client: FlaskClient, jwt_generator): + def test_should_serve_file(self, client: TestClient, jwt_generator): response = client.get( "/log/sample.log", headers={ "Authorization": jwt_generator.generate({"filename": "sample.log"}), }, ) - assert response.data.decode() == LOG_DATA + assert response.text == LOG_DATA assert response.status_code == 200 - def test_forbidden_different_logname(self, client: FlaskClient, jwt_generator): + def test_forbidden_different_logname(self, client: TestClient, jwt_generator): response = client.get( "/log/sample.log", headers={ @@ -129,7 +124,7 @@ def test_forbidden_different_logname(self, client: FlaskClient, jwt_generator): ) assert response.status_code == 403 - def test_forbidden_expired(self, client: FlaskClient, jwt_generator): + def test_forbidden_expired(self, client: TestClient, jwt_generator): with time_machine.travel("2010-01-14"): token = jwt_generator.generate({"filename": "sample.log"}) assert ( @@ -142,7 +137,7 @@ def test_forbidden_expired(self, client: FlaskClient, jwt_generator): == 403 ) - def test_forbidden_future(self, client: FlaskClient, jwt_generator): + def test_forbidden_future(self, client: TestClient, jwt_generator): with time_machine.travel(timezone.utcnow() + timedelta(seconds=3600)): token = jwt_generator.generate({"filename": "sample.log"}) assert ( @@ -155,7 +150,7 @@ def test_forbidden_future(self, client: FlaskClient, jwt_generator): == 403 ) - def test_ok_with_short_future_skew(self, client: FlaskClient, jwt_generator): + def test_ok_with_short_future_skew(self, client: TestClient, jwt_generator): print(f"Ts= {timezone.utcnow().timestamp()}") with time_machine.travel(timezone.utcnow() + timedelta(seconds=1)): print(f"Ts with travvel = {timezone.utcnow().timestamp()}") @@ -170,7 +165,7 @@ def test_ok_with_short_future_skew(self, client: FlaskClient, jwt_generator): == 200 ) - def test_ok_with_short_past_skew(self, client: FlaskClient, jwt_generator): + def test_ok_with_short_past_skew(self, client: TestClient, jwt_generator): with time_machine.travel(timezone.utcnow() - timedelta(seconds=31)): token = jwt_generator.generate({"filename": "sample.log"}) assert ( @@ -183,7 +178,7 @@ def test_ok_with_short_past_skew(self, client: FlaskClient, jwt_generator): == 200 ) - def test_forbidden_with_long_future_skew(self, client: FlaskClient, jwt_generator): + def test_forbidden_with_long_future_skew(self, client: TestClient, jwt_generator): with time_machine.travel(timezone.utcnow() + timedelta(seconds=40)): token = jwt_generator.generate({"filename": "sample.log"}) assert ( @@ -196,7 +191,7 @@ def test_forbidden_with_long_future_skew(self, client: FlaskClient, jwt_generato == 403 ) - def test_forbidden_with_long_past_skew(self, client: FlaskClient, jwt_generator): + def test_forbidden_with_long_past_skew(self, client: TestClient, jwt_generator): with time_machine.travel(timezone.utcnow() - timedelta(seconds=40)): token = jwt_generator.generate({"filename": "sample.log"}) assert ( @@ -209,7 +204,7 @@ def test_forbidden_with_long_past_skew(self, client: FlaskClient, jwt_generator) == 403 ) - def test_wrong_audience(self, client: FlaskClient, different_audience): + def test_wrong_audience(self, client: TestClient, different_audience): assert ( client.get( "/log/sample.log", @@ -221,7 +216,7 @@ def test_wrong_audience(self, client: FlaskClient, different_audience): ) @pytest.mark.parametrize("claim_to_remove", ["iat", "exp", "nbf", "aud"]) - def test_missing_claims(self, claim_to_remove: str, client: FlaskClient, secret_key): + def test_missing_claims(self, claim_to_remove: str, client: TestClient, secret_key): jwt_dict = { "aud": "task-instance-logs", "iat": timezone.utcnow(), diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index dee2f47ddb6d3..62b5b70fa9bc0 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -788,6 +788,7 @@ grpc GSoD gsuite Gunicorn +gunicorn gz Gzip gzipped