Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send important executor logs to task logs #40468

Merged
merged 21 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.task_context_logger import TaskContextLogger
from airflow.utils.state import TaskInstanceState

PARALLELISM: int = conf.getint("core", "PARALLELISM")
Expand Down Expand Up @@ -122,6 +123,7 @@ class BaseExecutor(LoggingMixin):
job_id: None | int | str = None
name: None | ExecutorName = None
callback_sink: BaseCallbackSink | None = None
task_context_logger: TaskContextLogger
vincbeck marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, parallelism: int = PARALLELISM):
super().__init__()
Expand All @@ -130,6 +132,10 @@ def __init__(self, parallelism: int = PARALLELISM):
self.running: set[TaskInstanceKey] = set()
self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)
self.task_context_logger: TaskContextLogger = TaskContextLogger(
component_name="Executor",
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
call_site_logger=self.log,
)

def __repr__(self):
    """Return a concise representation, ``<ClassName>(parallelism=<n>)``."""
    cls_name = type(self).__name__
    return f"{cls_name}(parallelism={self.parallelism})"
Expand All @@ -149,7 +155,7 @@ def queue_command(
self.log.info("Adding to queue: %s", command)
self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance)
else:
self.log.error("could not queue task %s", task_instance.key)
self.task_context_logger.error("could not queue task %s", task_instance.key, ti=task_instance)
vincbeck marked this conversation as resolved.
Show resolved Hide resolved

def queue_task_instance(
self,
Expand Down Expand Up @@ -284,8 +290,12 @@ def trigger_tasks(self, open_slots: int) -> None:
self.log.info("queued but still running; attempt=%s task=%s", attempt.total_tries, key)
continue
# Otherwise, we give up and remove the task from the queue.
self.log.error(
"could not queue task %s (still running after %d attempts)", key, attempt.total_tries
self.task_context_logger.error(
"Could not queue task %s as it is seen as still running after %d attempts (tried for %d seconds). It looks like it was killed externally. Look for external reasons why it has been killed (likely a bug or deployment issue).",
key,
attempt.total_tries,
RunningRetryAttemptType.MIN_SECONDS,
ti=ti,
)
del self.attempts[key]
del self.queued_tasks[key]
Expand Down
28 changes: 12 additions & 16 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,11 @@ def __handle_failed_task(self, task_arn: str, reason: str):
)
)
else:
self.log.error(
self.task_context_logger.error(
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
"Airflow task %s has failed a maximum of %s times. Marking as failed",
task_key,
failure_count,
ti=task_key,
)
self.fail(task_key)
self.active_workers.pop_by_key(task_key)
Expand All @@ -347,7 +348,7 @@ def attempt_task_runs(self):
queue = ecs_task.queue
exec_config = ecs_task.executor_config
attempt_number = ecs_task.attempt_number
_failure_reasons = []
failure_reasons = []
if timezone.utcnow() < ecs_task.next_attempt_time:
self.pending_tasks.append(ecs_task)
continue
Expand All @@ -361,23 +362,21 @@ def attempt_task_runs(self):
if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
self.pending_tasks.append(ecs_task)
raise
_failure_reasons.append(str(e))
failure_reasons.append(str(e))
except Exception as e:
# Failed to even get a response back from the Boto3 API or something else went
# wrong. For any possible failure we want to add the exception reasons to the
# failure list so that it is logged to the user and most importantly the task is
# added back to the pending list to be retried later.
_failure_reasons.append(str(e))
failure_reasons.append(str(e))
else:
# We got a response back, check if there were failures. If so, add them to the
# failures list so that it is logged to the user and most importantly the task
# is added back to the pending list to be retried later.
if run_task_response["failures"]:
_failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])
failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])

if _failure_reasons:
for reason in _failure_reasons:
failure_reasons[reason] += 1
if failure_reasons:
# Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
ecs_task.attempt_number += 1
Expand All @@ -386,14 +385,16 @@ def attempt_task_runs(self):
)
self.pending_tasks.append(ecs_task)
else:
self.log.error(
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
"ECS task %s has failed a maximum of %s times. Marking as failed",
self.task_context_logger.error(
vincbeck marked this conversation as resolved.
Show resolved Hide resolved
"ECS task %s has failed a maximum of %s times. Marking as failed. Reasons: %s",
task_key,
attempt_number,
", ".join(failure_reasons),
ti=task_key,
)
self.fail(task_key)
elif not run_task_response["tasks"]:
self.log.error("ECS RunTask Response: %s", run_task_response)
self.task_context_logger.error("ECS RunTask Response: %s", run_task_response, ti=task_key)
raise EcsExecutorException(
"No failures and no ECS tasks provided in response. This should never happen."
)
Expand All @@ -407,11 +408,6 @@ def attempt_task_runs(self):
# executor feature).
# TODO: remove when min airflow version >= 2.9.2
pass
if failure_reasons:
self.log.error(
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
dict(failure_reasons),
)

def _run_task(
self, task_id: TaskInstanceKey, cmd: CommandType, queue: str, exec_config: ExecutorConfigType
Expand Down
39 changes: 6 additions & 33 deletions airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,21 @@

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.exceptions import RemovedInAirflow3Warning
from airflow.executors.executor_loader import ExecutorLoader
from airflow.utils.context import Context
from airflow.utils.helpers import parse_template_string, render_template_to_string
from airflow.utils.log.logging_mixin import SetContextPropagate
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
from airflow.utils.log.task_context_logger import ensure_ti
from airflow.utils.session import provide_session
from airflow.utils.state import State, TaskInstanceState

if TYPE_CHECKING:
from pendulum import DateTime

from airflow.models import DagRun
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

Expand Down Expand Up @@ -140,32 +141,6 @@ def _interleave_logs(*logs):
last = v


def _ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, session) -> TaskInstance:
    """
    Resolve *ti* to a concrete ``TaskInstance``.

    A ``TaskInstance`` argument is returned as-is; any other accepted type is
    looked up in the database by (task_id, dag_id, run_id, map_index).

    :param ti: a TaskInstance, TaskInstanceKey, or TaskInstancePydantic
    :param session: SQLAlchemy ORM session used for the lookup
    :raises AirflowException: if no matching TaskInstance row exists
    """
    # Imported locally to avoid a circular import at module load time.
    from airflow.models.taskinstance import TaskInstance

    if isinstance(ti, TaskInstance):
        return ti

    resolved = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.task_id == ti.task_id,
            TaskInstance.dag_id == ti.dag_id,
            TaskInstance.run_id == ti.run_id,
            TaskInstance.map_index == ti.map_index,
        )
        .one_or_none()
    )
    if resolved is None:
        raise AirflowException(f"Could not find TaskInstance for {ti}")
    # Propagate the caller-supplied try_number onto the fetched row
    # (the DB value may differ from the one carried by the key).
    resolved.try_number = ti.try_number
    return resolved


class FileTaskHandler(logging.Handler):
"""
FileTaskHandler is a python log handler that handles and reads task instance logs.
Expand Down Expand Up @@ -265,9 +240,9 @@ def close(self):
@internal_api_call
@provide_session
def _render_filename_db_access(
*, ti, try_number: int, session=None
*, ti: TaskInstance | TaskInstancePydantic, try_number: int, session=None
) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]:
ti = _ensure_ti(ti, session)
ti = ensure_ti(ti, session)
dag_run = ti.get_dagrun(session=session)
template = dag_run.get_log_template(session=session).filename
str_tpl, jinja_tpl = parse_template_string(template)
Expand All @@ -281,9 +256,7 @@ def _render_filename_db_access(
filename = render_template_to_string(jinja_tpl, context)
return dag_run, ti, str_tpl, filename

def _render_filename(
self, ti: TaskInstance | TaskInstanceKey | TaskInstancePydantic, try_number: int
) -> str:
def _render_filename(self, ti: TaskInstance | TaskInstancePydantic, try_number: int) -> str:
"""Return the worker log filename."""
dag_run, ti, str_tpl, filename = self._render_filename_db_access(ti=ti, try_number=try_number)
if filename:
Expand Down
55 changes: 44 additions & 11 deletions airflow/utils/log/task_context_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,44 @@
from typing import TYPE_CHECKING

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.session import create_session

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.utils.log.file_task_handler import FileTaskHandler

logger = logging.getLogger(__name__)


def ensure_ti(ti: TaskInstanceKey | TaskInstance | TaskInstancePydantic, session) -> TaskInstance:
    """
    Given a TaskInstance or a TaskInstance-like key, return a real ``TaskInstance``.

    If *ti* is already a ``TaskInstance`` it is returned unchanged; otherwise the
    matching record is fetched from the database.

    :param ti: a TaskInstance, TaskInstanceKey, or TaskInstancePydantic
    :param session: SQLAlchemy ORM session used for the database lookup
    :raises AirflowException: if no TaskInstance is found in the database
    """
    # Local import avoids a circular dependency with airflow.models.
    from airflow.models.taskinstance import TaskInstance

    if isinstance(ti, TaskInstance):
        return ti

    query = session.query(TaskInstance).filter(
        TaskInstance.task_id == ti.task_id,
        TaskInstance.dag_id == ti.dag_id,
        TaskInstance.run_id == ti.run_id,
        TaskInstance.map_index == ti.map_index,
    )
    found = query.one_or_none()
    if not found:
        raise AirflowException(f"Could not find TaskInstance for {ti}")
    # Keep the try_number the caller asked about, not whatever the row holds now.
    found.try_number = ti.try_number
    return found


class TaskContextLogger:
"""
Class for sending messages to task instance logs from outside task execution context.
Expand All @@ -57,7 +87,7 @@ def __init__(self, component_name: str, call_site_logger: Logger | None = None):
def _should_enable(self) -> bool:
if not conf.getboolean("logging", "enable_task_context_logger"):
return False
if not getattr(self.task_handler, "supports_task_context_logging", False):
if not self.task_handler:
dstandish marked this conversation as resolved.
Show resolved Hide resolved
logger.warning("Task handler does not support task context logging")
return False
logger.info("Task context logging is enabled")
Expand All @@ -78,13 +108,13 @@ def _get_task_handler() -> FileTaskHandler | None:
assert isinstance(h, FileTaskHandler)
return h

def _log(self, level: int, msg: str, *args, ti: TaskInstance):
def _log(self, level: int, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message to the task instance logs.

:param level: the log level
:param msg: the message to relay to task context log
:param ti: the task instance
:param ti: the task instance or the task instance key
"""
if self.call_site_logger and self.call_site_logger.isEnabledFor(level=level):
with suppress(Exception):
Expand All @@ -98,6 +128,9 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance):

task_handler = copy(self.task_handler)
try:
if isinstance(ti, TaskInstanceKey):
with create_session() as session:
ti = ensure_ti(ti, session)
task_handler.set_context(ti, identifier=self.component_name)
if hasattr(task_handler, "mark_end_on_close"):
task_handler.mark_end_on_close = False
Expand All @@ -109,7 +142,7 @@ def _log(self, level: int, msg: str, *args, ti: TaskInstance):
finally:
task_handler.close()

def critical(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
    """
    Send a CRITICAL-level message to the given task instance's logs.

    :param msg: the message to relay to the task context log
    :param ti: the task instance, or its key (resolved to a TI before logging)
    """
    self._log(logging.CRITICAL, msg, *args, ti=ti)

def fatal(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
    """
    Send a FATAL-level message to the given task instance's logs.

    :param msg: the message to relay to the task context log
    :param ti: the task instance, or its key (resolved to a TI before logging)
    """
    self._log(logging.FATAL, msg, *args, ti=ti)

def error(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
    """
    Send an ERROR-level message to the given task instance's logs.

    :param msg: the message to relay to the task context log
    :param ti: the task instance, or its key (resolved to a TI before logging)
    """
    self._log(logging.ERROR, msg, *args, ti=ti)

def warn(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
    """
    Send a WARN-level message to the given task instance's logs.

    Emitted at ``logging.WARNING`` (the stdlib has no separate WARN level).

    :param msg: the message to relay to the task context log
    :param ti: the task instance, or its key (resolved to a TI before logging)
    """
    self._log(logging.WARNING, msg, *args, ti=ti)

def warning(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
    """
    Send a WARNING-level message to the given task instance's logs.

    :param msg: the message to relay to the task context log
    :param ti: the task instance, or its key (resolved to a TI before logging)
    """
    self._log(logging.WARNING, msg, *args, ti=ti)

def info(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
    """
    Send an INFO-level message to the given task instance's logs.

    :param msg: the message to relay to the task context log
    :param ti: the task instance, or its key (resolved to a TI before logging)
    """
    self._log(logging.INFO, msg, *args, ti=ti)

def debug(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
    """
    Send a DEBUG-level message to the given task instance's logs.

    :param msg: the message to relay to the task context log
    :param ti: the task instance, or its key (resolved to a TI before logging)
    """
    self._log(logging.DEBUG, msg, *args, ti=ti)

def notset(self, msg: str, *args, ti: TaskInstance):
def notset(self, msg: str, *args, ti: TaskInstance | TaskInstanceKey):
"""
Emit a log message with level NOTSET to the task instance logs.

Expand Down
Loading