diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index d72831490e2bf..f77ea91a5e8fa 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -37,7 +37,9 @@
 from celery.backends.base import BaseKeyValueStoreBackend
 from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cleanup
 from celery.result import AsyncResult
+from setproctitle import setproctitle  # pylint: disable=no-name-in-module
 
+import airflow.settings as settings
 from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
@@ -78,6 +80,45 @@ def execute_command(command_to_exec: CommandType) -> None:
     """Executes command."""
     BaseExecutor.validate_command(command_to_exec)
     log.info("Executing command in Celery: %s", command_to_exec)
+
+    if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
+        _execute_in_subprocees(command_to_exec)
+    else:
+        _execute_in_fork(command_to_exec)
+
+
+def _execute_in_fork(command_to_exec: CommandType) -> None:
+    pid = os.fork()
+    if pid:
+        # In parent, wait for the child
+        pid, ret = os.waitpid(pid, 0)
+        if ret == 0:
+            return
+
+        raise AirflowException('Celery command failed on host: ' + get_hostname())
+
+    from airflow.sentry import Sentry
+
+    ret = 1
+    try:
+        from airflow.cli.cli_parser import get_parser
+        parser = get_parser()
+        # [1:] - remove "airflow" from the start of the command
+        args = parser.parse_args(command_to_exec[1:])
+
+        setproctitle(f"airflow task supervisor: {command_to_exec}")
+
+        args.func(args)
+        ret = 0
+    except Exception as e:  # pylint: disable=broad-except
+        log.error("Failed to execute task %s.", str(e))
+        ret = 1
+    finally:
+        Sentry.flush()
+        os._exit(ret)  # pylint: disable=protected-access
+
+
+def _execute_in_subprocees(command_to_exec: CommandType) -> None:
     env = os.environ.copy()
     try:
         # pylint: disable=unexpected-keyword-arg
diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py
index 7408106dc8bfa..fcf0e3603fedc 100644
--- a/tests/executors/test_celery_executor.py
+++ b/tests/executors/test_celery_executor.py
@@ -225,18 +225,20 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock
         [['airflow', 'version'], ValueError],
         [['airflow', 'tasks', 'run'], None]
     ))
-    @mock.patch('subprocess.check_output')
-    def test_command_validation(self, command, expected_exception, mock_check_output):
+    def test_command_validation(self, command, expected_exception):
         # Check that we validate _on the receiving_ side, not just sending side
-        if expected_exception:
-            with pytest.raises(expected_exception):
+        with mock.patch('airflow.executors.celery_executor._execute_in_subprocees') as mock_subproc, \
+                mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork:
+            if expected_exception:
+                with pytest.raises(expected_exception):
+                    celery_executor.execute_command(command)
+                mock_subproc.assert_not_called()
+                mock_fork.assert_not_called()
+            else:
                 celery_executor.execute_command(command)
-                mock_check_output.assert_not_called()
-        else:
-            celery_executor.execute_command(command)
-            mock_check_output.assert_called_once_with(
-                command, stderr=mock.ANY, close_fds=mock.ANY, env=mock.ANY,
-            )
+                # One of these should be called.
+                assert mock_subproc.call_args == ((command,),) or \
+                    mock_fork.call_args == ((command,),)
 
     @pytest.mark.backend("mysql", "postgres")
     def test_try_adopt_task_instances_none(self):
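
The patch's _execute_in_fork relies on the standard POSIX fork-and-wait pattern: the parent blocks on waitpid and treats a non-zero status as failure, while the child does the work and reports its result via os._exit so it never unwinds the parent's stack or runs its cleanup handlers. Below is a minimal, self-contained sketch of that pattern using only the Python standard library; it is not Airflow code, and the names run_in_fork and the RuntimeError message are illustrative only.

    import os


    def run_in_fork(work) -> None:
        """Run `work()` in a forked child; raise in the parent if the child fails."""
        pid = os.fork()
        if pid:
            # Parent: wait for the child and inspect its exit status.
            _, status = os.waitpid(pid, 0)
            if status != 0:
                raise RuntimeError("child process failed")
            return

        # Child: run the work, then exit immediately with the result code.
        ret = 1
        try:
            work()
            ret = 0
        finally:
            os._exit(ret)


    if __name__ == "__main__":
        run_in_fork(lambda: print("hello from the forked child"))

The same structure explains why the diff sets ret inside try/except and only calls os._exit in finally: the exit code is the sole channel the child has for telling the waiting parent whether the Airflow task command succeeded.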