Fix manual triggering of tasks with the Celery executor
Jorricks committed Sep 14, 2021
1 parent c73004d commit 0cbe996
Showing 6 changed files with 96 additions and 9 deletions.
9 changes: 8 additions & 1 deletion airflow/cli/commands/task_command.py
@@ -22,7 +22,7 @@
import os
import textwrap
from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
from typing import List
from typing import List, Optional

from pendulum.parsing.exceptions import ParserError
from sqlalchemy.orm.exc import NoResultFound
@@ -157,6 +157,7 @@ def _run_task_by_local_task_job(args, ti):
ignore_task_deps=args.ignore_dependencies,
ignore_ti_state=args.force,
pool=args.pool,
external_executor_id=_extract_external_executor_id(args),
)
try:
run_job.run()
@@ -184,6 +185,12 @@ def _run_raw_task(args, ti: TaskInstance) -> None:
)


def _extract_external_executor_id(args) -> Optional[str]:
if hasattr(args, "external_executor_id"):
return getattr(args, "external_executor_id")
return os.environ.get("external_executor_id", None)


@contextmanager
def _capture_task_logs(ti):
"""Manage logging context for a task run
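For readers following the CLI change, here is a minimal, standalone sketch (not part of the commit) of the precedence the new _extract_external_executor_id helper encodes: an external_executor_id attribute set on the parsed args (the fork path) wins, otherwise the external_executor_id environment variable (the subprocess path) is consulted. SimpleNamespace stands in for the argparse namespace, and the id values are made up for illustration.

import os
from types import SimpleNamespace
from typing import Optional


def extract_external_executor_id(args) -> Optional[str]:
    # Prefer an id attached to the parsed CLI args (set by the forked worker) ...
    if hasattr(args, "external_executor_id"):
        return getattr(args, "external_executor_id")
    # ... and fall back to the environment (set for the subprocess path).
    return os.environ.get("external_executor_id", None)


# Fork path: the Celery worker attaches the id directly to the args namespace.
fork_args = SimpleNamespace(external_executor_id="abcdef-124215-abcdef")
assert extract_external_executor_id(fork_args) == "abcdef-124215-abcdef"

# Subprocess path: the id arrives through the copied environment instead.
os.environ["external_executor_id"] = "abcdef-124215-abcdef"
assert extract_external_executor_id(SimpleNamespace()) == "abcdef-124215-abcdef"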
16 changes: 10 additions & 6 deletions airflow/executors/celery_executor.py
@@ -81,14 +81,16 @@ def execute_command(command_to_exec: CommandType) -> None:
"""Executes command."""
BaseExecutor.validate_command(command_to_exec)
log.info("Executing command in Celery: %s", command_to_exec)
celery_task_id = app.current_task.request.id
log.info(f"Celery task ID: {celery_task_id}")

if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
_execute_in_subprocess(command_to_exec)
_execute_in_subprocess(command_to_exec, celery_task_id)
else:
_execute_in_fork(command_to_exec)
_execute_in_fork(command_to_exec, celery_task_id)


def _execute_in_fork(command_to_exec: CommandType) -> None:
def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None:
pid = os.fork()
if pid:
# In parent, wait for the child
@@ -111,6 +113,8 @@ def _execute_in_fork(command_to_exec: CommandType) -> None:
# [1:] - remove "airflow" from the start of the command
args = parser.parse_args(command_to_exec[1:])
args.shut_down_logging = False
if celery_task_id:
args.external_executor_id = celery_task_id

setproctitle(f"airflow task supervisor: {command_to_exec}")

@@ -125,12 +129,12 @@ def _execute_in_fork(command_to_exec: CommandType) -> None:
os._exit(ret)


def _execute_in_subprocess(command_to_exec: CommandType) -> None:
def _execute_in_subprocess(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None:
env = os.environ.copy()
if celery_task_id:
env["external_executor_id"] = celery_task_id
try:

subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env)

except subprocess.CalledProcessError as e:
log.exception('execute_command encountered a CalledProcessError')
log.error(e.output)
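To make the subprocess path concrete, here is a simplified, self-contained sketch (standard library only, assuming a python executable on PATH; the function name and id value are illustrative, not Airflow's API) of how the Celery task id is handed to the child process through a copy of the environment, mirroring _execute_in_subprocess above.

import os
import subprocess
from typing import List, Optional


def execute_in_subprocess(command: List[str], celery_task_id: Optional[str] = None) -> bytes:
    # Copy the parent environment and expose the Celery task id to the child,
    # just as the patched _execute_in_subprocess does before calling check_output.
    env = os.environ.copy()
    if celery_task_id:
        env["external_executor_id"] = celery_task_id
    return subprocess.check_output(command, stderr=subprocess.STDOUT, close_fds=True, env=env)


# Hypothetical child command that echoes the id it received from the environment.
output = execute_in_subprocess(
    ["python", "-c", "import os; print(os.environ.get('external_executor_id'))"],
    celery_task_id="abcdef-124215-abcdef",
)
assert output.strip() == b"abcdef-124215-abcdef"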
3 changes: 3 additions & 0 deletions airflow/jobs/local_task_job.py
@@ -52,6 +52,7 @@ def __init__(
mark_success: bool = False,
pickle_id: Optional[str] = None,
pool: Optional[str] = None,
external_executor_id: Optional[str] = None,
*args,
**kwargs,
):
@@ -64,6 +65,7 @@
self.pool = pool
self.pickle_id = pickle_id
self.mark_success = mark_success
self.external_executor_id = external_executor_id
self.task_runner = None

# terminating state is used so that a job don't try to
@@ -92,6 +94,7 @@ def signal_handler(signum, frame):
ignore_ti_state=self.ignore_ti_state,
job_id=self.id,
pool=self.pool,
external_executor_id=self.external_executor_id,
):
self.log.info("Task is not able to be run")
return
4 changes: 4 additions & 0 deletions airflow/models/taskinstance.py
@@ -1128,6 +1128,7 @@ def check_and_change_state_before_execution(
test_mode: bool = False,
job_id: Optional[str] = None,
pool: Optional[str] = None,
external_executor_id: Optional[str] = None,
session=None,
) -> bool:
"""
@@ -1153,6 +1154,8 @@
:type job_id: str
:param pool: specifies the pool to use to run the task instance
:type pool: str
:param external_executor_id: The identifier of the celery executor
:type external_executor_id: str
:param session: SQLAlchemy ORM Session
:type session: Session
:return: whether the state was changed to running or not
@@ -1234,6 +1237,7 @@ def check_and_change_state_before_execution(
if not test_mode:
session.add(Log(State.RUNNING, self))
self.state = State.RUNNING
self.external_executor_id = external_executor_id
self.end_date = None
if not test_mode:
session.merge(self)
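Taken together, the chain is: the Celery worker captures app.current_task.request.id, the CLI picks it up from the args or the environment, LocalTaskJob forwards it, and the task instance records it as it moves to RUNNING. A small stand-in sketch of that final step (a plain dataclass, no Airflow objects; the class and id value are hypothetical):

from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeTaskInstance:
    state: str = "queued"
    external_executor_id: Optional[str] = None

    def check_and_change_state_before_execution(self, external_executor_id: Optional[str] = None) -> bool:
        # Simplified: remember which executor task is running us, then flip to RUNNING,
        # mirroring the assignment added to taskinstance.py above.
        self.external_executor_id = external_executor_id
        self.state = "running"
        return True


ti = FakeTaskInstance()
ti.check_and_change_state_before_execution(external_executor_id="abcdef-124215-abcdef")
assert ti.state == "running"
assert ti.external_executor_id == "abcdef-124215-abcdef"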
62 changes: 62 additions & 0 deletions tests/cli/commands/test_task_command.py
@@ -134,6 +134,7 @@ def test_run_with_existing_dag_run_id(self, mock_local_job):
ignore_ti_state=False,
pickle_id=None,
pool=None,
external_executor_id=None,
)

@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
@@ -435,6 +436,67 @@ def assert_log_line(self, text, logs_list, expect_from_logging_mixin=False):
assert "logging_mixin.py" not in log_line
return log_line

@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_external_executor_id_present_for_fork_run_task(self, mock_local_job):
naive_date = datetime(2016, 1, 1)
dag_id = 'test_run_fork_has_external_executor_id'
task0_id = 'test_run_fork_task'

dag = self.dagbag.get_dag(dag_id)
args_list = [
'tasks',
'run',
'--local',
dag_id,
task0_id,
naive_date.isoformat(),
]
args = self.parser.parse_args(args_list)
args.external_executor_id = "ABCD12345"

task_command.task_run(args, dag=dag)
mock_local_job.assert_called_once_with(
task_instance=mock.ANY,
mark_success=False,
pickle_id=None,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
pool=None,
external_executor_id="ABCD12345",
)

@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_external_executor_id_present_for_process_run_task(self, mock_local_job):
naive_date = datetime(2016, 1, 1)
dag_id = 'test_run_process_has_external_executor_id'
task0_id = 'test_run_process_task'

dag = self.dagbag.get_dag(dag_id)
args_list = [
'tasks',
'run',
'--local',
dag_id,
task0_id,
naive_date.isoformat(),
]
args = self.parser.parse_args(args_list)
with mock.patch.dict(os.environ, {"external_executor_id": "12345FEDCBA"}):
task_command.task_run(args, dag=dag)
mock_local_job.assert_called_once_with(
task_instance=mock.ANY,
mark_success=False,
pickle_id=None,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
pool=None,
external_executor_id="ABCD12345",
)

@unittest.skipIf(not hasattr(os, 'fork'), "Forking not available")
def test_logging_with_run_task(self):
# We are not using self.assertLogs as we want to verify what actually is stored in the Log file
11 changes: 9 additions & 2 deletions tests/executors/test_celery_executor.py
@@ -287,7 +287,12 @@ def test_command_validation(self, command, expected_exception):
# Check that we validate _on the receiving_ side, not just sending side
with mock.patch(
'airflow.executors.celery_executor._execute_in_subprocess'
) as mock_subproc, mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork:
) as mock_subproc, mock.patch(
'airflow.executors.celery_executor._execute_in_fork'
) as mock_fork, mock.patch(
"celery.app.task.Task.request"
) as mock_task:
mock_task.id = "abcdef-124215-abcdef"
if expected_exception:
with pytest.raises(expected_exception):
celery_executor.execute_command(command)
@@ -296,7 +301,9 @@ def test_command_validation(self, command, expected_exception):
else:
celery_executor.execute_command(command)
# One of these should be called.
assert mock_subproc.call_args == ((command,),) or mock_fork.call_args == ((command,),)
assert mock_subproc.call_args == (
(command, "abcdef-124215-abcdef"),
) or mock_fork.call_args == ((command, "abcdef-124215-abcdef"),)

@pytest.mark.backend("mysql", "postgres")
def test_try_adopt_task_instances_none(self):