Refactor commands to unify daemon context handling #34945

Merged (25 commits) on Oct 24, 2023
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -509,6 +509,7 @@ repos:
^airflow/api_connexion/openapi/v1.yaml$|
^airflow/auth/managers/fab/security_manager/|
^airflow/cli/commands/webserver_command.py$|
^airflow/cli/commands/internal_api_command.py$|
^airflow/config_templates/|
^airflow/models/baseoperator.py$|
^airflow/operators/__init__.py$|
89 changes: 31 additions & 58 deletions airflow/cli/commands/celery_command.py
@@ -23,19 +23,18 @@
from contextlib import contextmanager
from multiprocessing import Process

import daemon
import psutil
import sqlalchemy.exc
from celery import maybe_patch_concurrency # type: ignore[attr-defined]
from celery.app.defaults import DEFAULT_TASK_LOG_FMT
from celery.signals import after_setup_logger
from daemon.pidfile import TimeoutPIDLockFile
from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile

from airflow import settings
from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.configuration import conf
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations, setup_logging
from airflow.utils.cli import setup_locations
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.serve_logs import serve_logs

@@ -68,28 +67,9 @@ def flower(args):
if args.flower_conf:
options.append(f"--conf={args.flower_conf}")

if args.daemon:
pidfile, stdout, stderr, _ = setup_locations(
process="flower",
pid=args.pid,
stdout=args.stdout,
stderr=args.stderr,
log=args.log_file,
)
with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
stdout.truncate(0)
stderr.truncate(0)

ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(pidfile, -1),
stdout=stdout,
stderr=stderr,
umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
celery_app.start(options)
else:
celery_app.start(options)
run_command_with_daemon_option(
args=args, process_name="flower", callback=lambda: celery_app.start(options)
)


@contextmanager
@@ -152,15 +132,6 @@ def worker(args):
if autoscale is None and conf.has_option("celery", "worker_autoscale"):
autoscale = conf.get("celery", "worker_autoscale")

# Setup locations
pid_file_path, stdout, stderr, log_file = setup_locations(
process=WORKER_PROCESS_NAME,
pid=args.pid,
stdout=args.stdout,
stderr=args.stderr,
log=args.log_file,
)

if hasattr(celery_app.backend, "ResultSession"):
# Pre-create the database tables now, otherwise SQLA via Celery has a
# race condition where one of the subprocesses can die with "Table
@@ -181,6 +152,10 @@ def worker(args):
celery_log_level = conf.get("logging", "CELERY_LOGGING_LEVEL")
if not celery_log_level:
celery_log_level = conf.get("logging", "LOGGING_LEVEL")

# Setup pid file location
worker_pid_file_path, _, _, _ = setup_locations(process=WORKER_PROCESS_NAME, pid=args.pid)

# Setup Celery worker
options = [
"worker",
@@ -195,7 +170,7 @@
"--loglevel",
celery_log_level,
"--pidfile",
pid_file_path,
worker_pid_file_path,
]
if autoscale:
options.extend(["--autoscale", autoscale])
@@ -214,33 +189,31 @@
# executed.
maybe_patch_concurrency(["-P", pool])

if args.daemon:
# Run Celery worker as daemon
handle = setup_logging(log_file)

with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
if args.umask:
umask = args.umask
else:
umask = conf.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK)

stdout_handle.truncate(0)
stderr_handle.truncate(0)

daemon_context = daemon.DaemonContext(
files_preserve=[handle],
umask=int(umask, 8),
stdout=stdout_handle,
stderr=stderr_handle,
)
with daemon_context, _serve_logs(skip_serve_logs):
celery_app.worker_main(options)
_, stdout, stderr, log_file = setup_locations(
process=WORKER_PROCESS_NAME,
stdout=args.stdout,
stderr=args.stderr,
log=args.log_file,
)

else:
# Run Celery worker in the same process
def run_celery_worker():
with _serve_logs(skip_serve_logs):
celery_app.worker_main(options)

if args.umask:
umask = args.umask
else:
umask = conf.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK)

run_command_with_daemon_option(
args=args,
process_name=WORKER_PROCESS_NAME,
callback=run_celery_worker,
should_setup_logging=True,
umask=umask,
pid_file=worker_pid_file_path,
)


@cli_utils.action_cli
@providers_configuration_loaded
82 changes: 82 additions & 0 deletions airflow/cli/commands/daemon_utils.py
@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import signal
from argparse import Namespace
from typing import Callable

from daemon import daemon
from daemon.pidfile import TimeoutPIDLockFile

from airflow import settings
from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler
from airflow.utils.process_utils import check_if_pidfile_process_is_running


def run_command_with_daemon_option(
    *,
    args: Namespace,
    process_name: str,
    callback: Callable,
    should_setup_logging: bool = False,
    umask: str = settings.DAEMON_UMASK,
    pid_file: str | None = None,
):
    """Run the command in a daemon process if daemon mode enabled or within this process if not.

    :param args: the set of arguments passed to the original CLI command
    :param process_name: process name used in naming log and PID files for the daemon
    :param callback: the actual command to run with or without daemon context
    :param should_setup_logging: if true, then a log file handler for the daemon process will be created
    :param umask: file access creation mask ("umask") to set for the process on daemon start
    :param pid_file: if specified, this file path is used to store the daemon process PID.
        If not specified, a file path is generated with the default pattern.
"""
if args.daemon:
pid, stdout, stderr, log_file = setup_locations(
process=process_name, stdout=args.stdout, stderr=args.stderr, log=args.log_file
)
if pid_file:
pid = pid_file

# Check if the process is already running; if not but a pidfile exists, clean it up
check_if_pidfile_process_is_running(pid_file=pid, process_name=process_name)

if should_setup_logging:
files_preserve = [setup_logging(log_file)]
else:
files_preserve = None
with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
stdout_handle.truncate(0)
stderr_handle.truncate(0)

ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(pid, -1),
files_preserve=files_preserve,
stdout=stdout_handle,
stderr=stderr_handle,
umask=int(umask, 8),
)

with ctx:
callback()
else:
signal.signal(signal.SIGINT, sigint_handler)
signal.signal(signal.SIGTERM, sigint_handler)
signal.signal(signal.SIGQUIT, sigquit_handler)
callback()
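
For illustration, here is a minimal sketch of how a CLI entrypoint can delegate to the new helper. The command name "hello-daemon", the hello_daemon function and its inner callback are hypothetical and not part of this PR; args is assumed to be the Namespace produced by Airflow's CLI parser, carrying the daemon, stdout, stderr and log_file attributes.

from argparse import Namespace

from airflow.cli.commands.daemon_utils import run_command_with_daemon_option


def hello_daemon(args: Namespace):
    """Hypothetical command demonstrating the unified daemon handling."""

    def run():
        # The command's actual work. With --daemon it executes inside the
        # DaemonContext created by the helper; otherwise it runs in the
        # current process with SIGINT/SIGTERM/SIGQUIT handlers installed.
        print("hello from the (possibly daemonized) command")

    run_command_with_daemon_option(
        args=args,
        process_name="hello-daemon",
        callback=run,
        should_setup_logging=True,
    )

Commands that need a non-default umask or PID file location, such as the Celery worker above, can pass the optional umask and pid_file arguments instead of relying on the defaults.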
32 changes: 7 additions & 25 deletions airflow/cli/commands/dag_processor_command.py
@@ -21,16 +21,12 @@
from datetime import timedelta
from typing import Any

import daemon
from daemon.pidfile import TimeoutPIDLockFile

from airflow import settings
from airflow.cli.commands.daemon_utils import run_command_with_daemon_option
from airflow.configuration import conf
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.jobs.dag_processor_job_runner import DagProcessorJobRunner
from airflow.jobs.job import Job, run_job
from airflow.utils import cli as cli_utils
from airflow.utils.cli import setup_locations, setup_logging
from airflow.utils.providers_configuration_loader import providers_configuration_loaded

log = logging.getLogger(__name__)
@@ -66,23 +62,9 @@ def dag_processor(args):

job_runner = _create_dag_processor_job_runner(args)

if args.daemon:
pid, stdout, stderr, log_file = setup_locations(
"dag-processor", args.pid, args.stdout, args.stderr, args.log_file
)
handle = setup_logging(log_file)
with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
stdout_handle.truncate(0)
stderr_handle.truncate(0)

ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(pid, -1),
files_preserve=[handle],
stdout=stdout_handle,
stderr=stderr_handle,
umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
run_job(job=job_runner.job, execute_callable=job_runner._execute)
else:
run_job(job=job_runner.job, execute_callable=job_runner._execute)
run_command_with_daemon_option(
args=args,
process_name="dag-processor",
callback=lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute),
should_setup_logging=True,
)