Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions airflow/executors/celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@
from celery.backends.base import BaseKeyValueStoreBackend
from celery.backends.database import DatabaseBackend, Task as TaskDb, session_cleanup
from celery.result import AsyncResult
from setproctitle import setproctitle # pylint: disable=no-name-in-module

import airflow.settings as settings
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -78,6 +80,45 @@ def execute_command(command_to_exec: CommandType) -> None:
"""Executes command."""
BaseExecutor.validate_command(command_to_exec)
log.info("Executing command in Celery: %s", command_to_exec)

if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
_execute_in_subprocees(command_to_exec)
else:
_execute_in_fork(command_to_exec)


def _execute_in_fork(command_to_exec: CommandType) -> None:
    """Fork and run an ``airflow tasks run`` command in the child process.

    A fork (rather than running in-process) is required because the task
    entrypoint calls ``logging.shutdown()`` when it finishes, which would
    tear down the Celery worker's own logging if run in the worker process.

    :param command_to_exec: the full CLI command, starting with "airflow".
    :raises AirflowException: if the child process exits non-zero.
    """
    pid = os.fork()
    if pid:
        # In parent: block until the child finishes and check its status.
        pid, ret = os.waitpid(pid, 0)
        if ret == 0:
            return

        raise AirflowException('Celery command failed on host: ' + get_hostname())

    # In child: execute the task command via the Airflow CLI parser.
    # Import here (not at module level) to keep worker startup light.
    from airflow.sentry import Sentry

    ret = 1
    try:
        from airflow.cli.cli_parser import get_parser
        parser = get_parser()
        # [1:] - remove "airflow" from the start of the command
        args = parser.parse_args(command_to_exec[1:])

        setproctitle(f"airflow task supervisor: {command_to_exec}")

        args.func(args)
        ret = 0
    except Exception as e:  # pylint: disable=broad-except
        log.error("Failed to execute task %s.", str(e))
        ret = 1
    finally:
        # Flush Sentry events before the hard exit, then exit WITHOUT
        # running atexit handlers (os._exit), so the forked child never
        # unwinds the worker's interpreter state.
        Sentry.flush()
        os._exit(ret)  # pylint: disable=protected-access


def _execute_in_subprocees(command_to_exec: CommandType) -> None:
env = os.environ.copy()
try:
# pylint: disable=unexpected-keyword-arg
Expand Down
22 changes: 12 additions & 10 deletions tests/executors/test_celery_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,18 +225,20 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock
[['airflow', 'version'], ValueError],
[['airflow', 'tasks', 'run'], None]
))
def test_command_validation(self, command, expected_exception):
    """Validate commands _on the receiving_ side, not just the sending side.

    Invalid commands must be rejected before either execution path
    (fork or subprocess) is invoked; valid ones must reach exactly one
    of those paths.
    """
    with mock.patch('airflow.executors.celery_executor._execute_in_subprocees') as mock_subproc, \
            mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork:
        if expected_exception:
            with pytest.raises(expected_exception):
                celery_executor.execute_command(command)
            # Validation failed: neither execution path may run.
            mock_subproc.assert_not_called()
            mock_fork.assert_not_called()
        else:
            celery_executor.execute_command(command)
            # Exactly one execution path is chosen by settings; either is OK.
            assert mock_subproc.call_args == ((command,),) or \
                mock_fork.call_args == ((command,),)

@pytest.mark.backend("mysql", "postgres")
def test_try_adopt_task_instances_none(self):
Expand Down