diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index d8ae8c4f47dce..ffba0e1e24a8c 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -22,7 +22,7 @@ import os import textwrap from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress -from typing import List +from typing import List, Optional from pendulum.parsing.exceptions import ParserError from sqlalchemy.orm.exc import NoResultFound @@ -157,6 +157,7 @@ def _run_task_by_local_task_job(args, ti): ignore_task_deps=args.ignore_dependencies, ignore_ti_state=args.force, pool=args.pool, + external_executor_id=_extract_external_executor_id(args), ) try: run_job.run() @@ -184,6 +185,12 @@ def _run_raw_task(args, ti: TaskInstance) -> None: ) +def _extract_external_executor_id(args) -> Optional[str]: + if hasattr(args, "external_executor_id"): + return getattr(args, "external_executor_id") + return os.environ.get("external_executor_id", None) + + @contextmanager def _capture_task_logs(ti): """Manage logging context for a task run diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py index 56edb6e2d5de2..40a94f84ddf21 100644 --- a/airflow/executors/celery_executor.py +++ b/airflow/executors/celery_executor.py @@ -81,14 +81,16 @@ def execute_command(command_to_exec: CommandType) -> None: """Executes command.""" BaseExecutor.validate_command(command_to_exec) log.info("Executing command in Celery: %s", command_to_exec) + celery_task_id = app.current_task.request.id + log.info(f"Celery task ID: {celery_task_id}") if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER: - _execute_in_subprocess(command_to_exec) + _execute_in_subprocess(command_to_exec, celery_task_id) else: - _execute_in_fork(command_to_exec) + _execute_in_fork(command_to_exec, celery_task_id) -def _execute_in_fork(command_to_exec: CommandType) -> None: +def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None: pid = os.fork() if pid: # In parent, wait for the child @@ -111,6 +113,8 @@ def _execute_in_fork(command_to_exec: CommandType) -> None: # [1:] - remove "airflow" from the start of the command args = parser.parse_args(command_to_exec[1:]) args.shut_down_logging = False + if celery_task_id: + args.external_executor_id = celery_task_id setproctitle(f"airflow task supervisor: {command_to_exec}") @@ -125,12 +129,12 @@ def _execute_in_fork(command_to_exec: CommandType) -> None: os._exit(ret) -def _execute_in_subprocess(command_to_exec: CommandType) -> None: +def _execute_in_subprocess(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None: env = os.environ.copy() + if celery_task_id: + env["external_executor_id"] = celery_task_id try: - subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env) - except subprocess.CalledProcessError as e: log.exception('execute_command encountered a CalledProcessError') log.error(e.output) diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 7878216c7a9ba..6b53bd9357566 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -52,6 +52,7 @@ def __init__( mark_success: bool = False, pickle_id: Optional[str] = None, pool: Optional[str] = None, + external_executor_id: Optional[str] = None, *args, **kwargs, ): @@ -64,6 +65,7 @@ def __init__( self.pool = pool self.pickle_id = pickle_id self.mark_success = mark_success + self.external_executor_id = external_executor_id self.task_runner = None # terminating state is used so that a job don't try to @@ -92,6 +94,7 @@ def signal_handler(signum, frame): ignore_ti_state=self.ignore_ti_state, job_id=self.id, pool=self.pool, + external_executor_id=self.external_executor_id, ): self.log.info("Task is not able to be run") return diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 1e6802e1a49a1..9dca1f3e71c3c 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -1128,6 +1128,7 @@ def check_and_change_state_before_execution( test_mode: bool = False, job_id: Optional[str] = None, pool: Optional[str] = None, + external_executor_id: Optional[str] = None, session=None, ) -> bool: """ @@ -1153,6 +1154,8 @@ def check_and_change_state_before_execution( :type job_id: str :param pool: specifies the pool to use to run the task instance :type pool: str + :param external_executor_id: The identifier of the celery executor + :type external_executor_id: str :param session: SQLAlchemy ORM Session :type session: Session :return: whether the state was changed to running or not @@ -1234,6 +1237,7 @@ def check_and_change_state_before_execution( if not test_mode: session.add(Log(State.RUNNING, self)) self.state = State.RUNNING + self.external_executor_id = external_executor_id self.end_date = None if not test_mode: session.merge(self) diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 7cb98d3be4ccc..7d246c732dafe 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -134,6 +134,7 @@ def test_run_with_existing_dag_run_id(self, mock_local_job): ignore_ti_state=False, pickle_id=None, pool=None, + external_executor_id=None, ) @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") @@ -435,6 +436,67 @@ def assert_log_line(self, text, logs_list, expect_from_logging_mixin=False): assert "logging_mixin.py" not in log_line return log_line + @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") + def test_external_executor_id_present_for_fork_run_task(self, mock_local_job): + naive_date = datetime(2016, 1, 1) + dag_id = 'test_run_fork_has_external_executor_id' + task0_id = 'test_run_fork_task' + + dag = self.dagbag.get_dag(dag_id) + args_list = [ + 'tasks', + 'run', + '--local', + dag_id, + task0_id, + naive_date.isoformat(), + ] + args = self.parser.parse_args(args_list) + args.external_executor_id = "ABCD12345" + + task_command.task_run(args, dag=dag) + mock_local_job.assert_called_once_with( + task_instance=mock.ANY, + mark_success=False, + pickle_id=None, + ignore_all_deps=False, + ignore_depends_on_past=False, + ignore_task_deps=False, + ignore_ti_state=False, + pool=None, + external_executor_id="ABCD12345", + ) + + @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") + def test_external_executor_id_present_for_process_run_task(self, mock_local_job): + naive_date = datetime(2016, 1, 1) + dag_id = 'test_run_process_has_external_executor_id' + task0_id = 'test_run_process_task' + + dag = self.dagbag.get_dag(dag_id) + args_list = [ + 'tasks', + 'run', + '--local', + dag_id, + task0_id, + naive_date.isoformat(), + ] + args = self.parser.parse_args(args_list) + with mock.patch.dict(os.environ, {"external_executor_id": "12345FEDCBA"}): + task_command.task_run(args, dag=dag) + mock_local_job.assert_called_once_with( + task_instance=mock.ANY, + mark_success=False, + pickle_id=None, + ignore_all_deps=False, + ignore_depends_on_past=False, + ignore_task_deps=False, + ignore_ti_state=False, + pool=None, + external_executor_id="ABCD12345", + ) + @unittest.skipIf(not hasattr(os, 'fork'), "Forking not available") def test_logging_with_run_task(self): # We are not using self.assertLogs as we want to verify what actually is stored in the Log file diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index 4ab5a105ae141..636d49ddb151d 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -287,7 +287,12 @@ def test_command_validation(self, command, expected_exception): # Check that we validate _on the receiving_ side, not just sending side with mock.patch( 'airflow.executors.celery_executor._execute_in_subprocess' - ) as mock_subproc, mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork: + ) as mock_subproc, mock.patch( + 'airflow.executors.celery_executor._execute_in_fork' + ) as mock_fork, mock.patch( + "celery.app.task.Task.request" + ) as mock_task: + mock_task.id = "abcdef-124215-abcdef" if expected_exception: with pytest.raises(expected_exception): celery_executor.execute_command(command) @@ -296,7 +301,9 @@ def test_command_validation(self, command, expected_exception): else: celery_executor.execute_command(command) # One of these should be called. - assert mock_subproc.call_args == ((command,),) or mock_fork.call_args == ((command,),) + assert mock_subproc.call_args == ( + (command, "abcdef-124215-abcdef"), + ) or mock_fork.call_args == ((command, "abcdef-124215-abcdef"),) @pytest.mark.backend("mysql", "postgres") def test_try_adopt_task_instances_none(self):