@@ -112,6 +112,7 @@ def get_log(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run_id,
TaskInstance.map_index == map_index,
TaskInstance.try_number == try_number,
)
.join(TaskInstance.dag_run)
.options(joinedload(TaskInstance.trigger).joinedload(Trigger.triggerer_job))
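Filtering on `try_number` narrows the lookup to the specific attempt being requested rather than whatever attempt the live row currently represents. A minimal sketch of the idea, assuming a fallback to `TaskInstanceHistory` for attempts that have been archived by a retry (the fallback itself is not part of this hunk; `session` and the path parameters are assumed to be in scope as in `get_log`):

```python
# Hypothetical sketch, not the endpoint's actual code: look up the live row
# for the requested attempt, then fall back to the archived history table.
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory

ti = (
    session.query(TaskInstance)
    .filter(
        TaskInstance.task_id == task_id,
        TaskInstance.dag_id == dag_id,
        TaskInstance.run_id == dag_run_id,
        TaskInstance.map_index == map_index,
        TaskInstance.try_number == try_number,
    )
    .one_or_none()
)
if ti is None:
    # Earlier attempts live in TaskInstanceHistory after a retry.
    ti = (
        session.query(TaskInstanceHistory)
        .filter_by(
            dag_id=dag_id,
            task_id=task_id,
            run_id=dag_run_id,
            map_index=map_index,
            try_number=try_number,
        )
        .one_or_none()
    )
```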
11 changes: 11 additions & 0 deletions airflow-core/src/airflow/models/taskinstancehistory.py
@@ -52,6 +52,7 @@
if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.models import DagRun
from airflow.models.taskinstance import TaskInstance


@@ -161,6 +162,11 @@ def __init__(
Index("idx_tih_dag_run", dag_id, run_id),
)

@property
def id(self) -> str:
"""Alias for primary key field to support TaskInstance."""
return self.task_instance_id

@staticmethod
@provide_session
def record_ti(ti: TaskInstance, session: Session = NEW_SESSION) -> None:
@@ -183,3 +189,8 @@ def record_ti(ti: TaskInstance, session: Session = NEW_SESSION) -> None:
ti.set_duration()
ti_history = TaskInstanceHistory(ti, state=ti_history_state)
session.add(ti_history)

@provide_session
def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
"""Return the DagRun for this TaskInstanceHistory, matching TaskInstance."""
return self.dag_run
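Taken together, the `id` alias and `get_dagrun` are exactly the surface the log machinery below relies on, so an archived row can stand in for a live `TaskInstance`. A hedged sketch, with `tih` and `session` assumed to already exist:

```python
# Sketch: TaskInstanceHistory now quacks like TaskInstance for log handling.
# `tih` is assumed to be a TaskInstanceHistory row loaded from `session`.
log_id = tih.id                            # alias for tih.task_instance_id
dag_run = tih.get_dagrun(session=session)  # same accessor TaskInstance exposes
filename = f"{log_id}.log"                 # what a "{{ ti.id }}" template renders
```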
54 changes: 17 additions & 37 deletions airflow-core/src/airflow/utils/log/file_task_handler.py
@@ -34,7 +34,6 @@
from pydantic import BaseModel, ConfigDict, ValidationError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.executor_loader import ExecutorLoader
from airflow.utils.helpers import parse_template_string, render_template
from airflow.utils.log.logging_mixin import SetContextPropagate
@@ -45,7 +44,7 @@
if TYPE_CHECKING:
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.typing_compat import TypeAlias


@@ -180,32 +179,6 @@ def _interleave_logs(*logs: str | LogMessages) -> Iterable[StructuredLogMessage]
last = msg


def _ensure_ti(ti: TaskInstanceKey | TaskInstance, session) -> TaskInstance:
"""
Given TI | TIKey, return a TI object.

Will raise exception if no TI is found in the database.
"""
from airflow.models.taskinstance import TaskInstance

if isinstance(ti, TaskInstance):
return ti
val = (
session.query(TaskInstance)
.filter(
TaskInstance.task_id == ti.task_id,
TaskInstance.dag_id == ti.dag_id,
TaskInstance.run_id == ti.run_id,
TaskInstance.map_index == ti.map_index,
)
.one_or_none()
)
if not val:
raise AirflowException(f"Could not find TaskInstance for {ti}")
val.try_number = ti.try_number
return val


class FileTaskHandler(logging.Handler):
"""
FileTaskHandler is a python log handler that handles and reads task instance logs.
@@ -253,7 +226,9 @@ def __init__(
Some handlers emit "end of log" markers, and may not wish to do so when task defers.
"""

def set_context(self, ti: TaskInstance, *, identifier: str | None = None) -> None | SetContextPropagate:
def set_context(
self, ti: TaskInstance | TaskInstanceHistory, *, identifier: str | None = None
) -> None | SetContextPropagate:
"""
Provide task_instance context to airflow task handler.

@@ -309,9 +284,10 @@ def close(self):
self.handler.close()

@provide_session
def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSION) -> str:
def _render_filename(
self, ti: TaskInstance | TaskInstanceHistory, try_number: int, session=NEW_SESSION
) -> str:
"""Return the worker log filename."""
ti = _ensure_ti(ti, session)
dag_run = ti.get_dagrun(session=session)

date = dag_run.logical_date or dag_run.run_after
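With the `_ensure_ti` step gone, `_render_filename` touches only attributes that `TaskInstance` and `TaskInstanceHistory` share. A condensed sketch of the resulting flow, assuming the default log filename layout (not the method's full body):

```python
# Condensed sketch of the simplified flow: no TaskInstanceKey resolution,
# only attributes common to TaskInstance and TaskInstanceHistory.
def _render_filename_sketch(ti, try_number, session):
    dag_run = ti.get_dagrun(session=session)  # works for both classes now
    return (
        f"dag_id={ti.dag_id}/run_id={ti.run_id}/"
        f"task_id={ti.task_id}/attempt={try_number}.log"
    )
```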
@@ -344,8 +320,8 @@ def _render_filename(self, ti: TaskInstance, try_number: int, session=NEW_SESSIO
raise RuntimeError(f"Unable to render log filename for {ti}. This should never happen")

def _get_executor_get_task_log(
self, ti: TaskInstance
) -> Callable[[TaskInstance, int], tuple[list[str], list[str]]]:
self, ti: TaskInstance | TaskInstanceHistory
) -> Callable[[TaskInstance | TaskInstanceHistory, int], tuple[list[str], list[str]]]:
"""
Get the get_task_log method from executor of current task instance.

@@ -367,7 +343,7 @@

def _read(
self,
ti: TaskInstance,
ti: TaskInstance | TaskInstanceHistory,
try_number: int,
metadata: dict[str, Any] | None = None,
):
@@ -455,15 +431,19 @@ def _read(
return logs, {"end_of_log": end_of_log, "log_pos": log_pos}

@staticmethod
def _get_pod_namespace(ti: TaskInstance):
def _get_pod_namespace(ti: TaskInstance | TaskInstanceHistory):
pod_override = ti.executor_config.get("pod_override")
namespace = None
with suppress(Exception):
namespace = pod_override.metadata.namespace
return namespace or conf.get("kubernetes_executor", "namespace")
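`_get_pod_namespace` only inspects `executor_config` for a Kubernetes `pod_override`, which both classes carry, hence the widened annotation. An illustrative example of a task whose logs would be fetched from a non-default namespace (task name and namespace are made up):

```python
# Illustrative only: the handler would read this namespace back via
# ti.executor_config["pod_override"].metadata.namespace.
from kubernetes.client import models as k8s

from airflow.providers.standard.operators.python import PythonOperator

task = PythonOperator(
    task_id="noisy_task",
    python_callable=lambda: None,
    executor_config={
        "pod_override": k8s.V1Pod(metadata=k8s.V1ObjectMeta(namespace="team-a-logs")),
    },
)
```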

def _get_log_retrieval_url(
self, ti: TaskInstance, log_relative_path: str, log_type: LogType | None = None
self,
ti: TaskInstance | TaskInstanceHistory,
log_relative_path: str,
log_type: LogType | None = None,
) -> tuple[str, str]:
"""Given TI, generate URL with which to fetch logs from service log server."""
if log_type == LogType.TRIGGER:
Expand All @@ -487,7 +467,7 @@ def _get_log_retrieval_url(

def read(
self,
task_instance: TaskInstance,
task_instance: TaskInstance | TaskInstanceHistory,
try_number: int | None = None,
metadata: dict[str, Any] | None = None,
) -> tuple[list[StructuredLogMessage] | str, dict[str, Any]]:
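With the union threaded through `read`, callers such as the log API can pass either object unchanged. A usage sketch, assuming `handler` is a configured `FileTaskHandler` and `tih` an archived attempt:

```python
# Sketch: reading an archived attempt through the same public handler API
# used for live task instances. `handler` and `tih` are assumed to exist.
logs, metadata = handler.read(tih, try_number=tih.try_number, metadata={})
print(metadata.get("end_of_log"), logs)
```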
15 changes: 12 additions & 3 deletions airflow-core/src/airflow/utils/log/log_reader.py
@@ -33,6 +33,7 @@
from sqlalchemy.orm.session import Session

from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.typing_compat import TypeAlias

LogMessages: TypeAlias = Union[list[StructuredLogMessage], str]
@@ -48,7 +49,10 @@ class TaskLogReader:
"""Number of empty loop iterations before stopping the stream"""

def read_log_chunks(
self, ti: TaskInstance, try_number: int | None, metadata
self,
ti: TaskInstance | TaskInstanceHistory,
try_number: int | None,
metadata: LogMetadata,
) -> tuple[LogMessages, LogMetadata]:
"""
Read chunks of Task Instance logs.
@@ -70,7 +74,12 @@
"""
return self.log_handler.read(ti, try_number, metadata=metadata)

def read_log_stream(self, ti: TaskInstance, try_number: int | None, metadata: dict) -> Iterator[str]:
def read_log_stream(
self,
ti: TaskInstance | TaskInstanceHistory,
try_number: int | None,
metadata: LogMetadata,
) -> Iterator[str]:
"""
Continuously read log to the end.

@@ -147,7 +156,7 @@ def supports_external_link(self) -> bool:
@provide_session
def render_log_filename(
self,
ti: TaskInstance,
ti: TaskInstance | TaskInstanceHistory,
try_number: int | None = None,
*,
session: Session = NEW_SESSION,
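End to end, the same polymorphism holds at the reader level. A hedged usage sketch, assuming `tih` is an archived first attempt and the configured handler supports reads:

```python
# Sketch: streaming logs for a historical attempt via TaskLogReader.
from airflow.utils.log.log_reader import TaskLogReader

reader = TaskLogReader()
if reader.supports_read:
    for chunk in reader.read_log_stream(tih, try_number=1, metadata={}):
        print(chunk, end="")
```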
33 changes: 33 additions & 0 deletions airflow-core/tests/unit/utils/test_log_handlers.py
@@ -42,6 +42,7 @@
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.models.trigger import Trigger
from airflow.providers.standard.operators.python import PythonOperator
from airflow.utils.log.file_task_handler import (
@@ -606,6 +607,38 @@ def test_jinja_rendering_catchup_false(self, create_log_template, create_task_in
rendered_filename = fth._render_filename(filename_rendering_ti, 42)
assert expected_filename == rendered_filename

def test_jinja_id_in_template_for_history(
self, create_log_template, create_task_instance, logical_date, session
):
"""Test that Jinja template using ti.id works for both TaskInstance and TaskInstanceHistory"""
create_log_template("{{ ti.id }}.log")
ti = create_task_instance(
dag_id="dag_history_test",
task_id="history_task",
run_type=DagRunType.SCHEDULED,
logical_date=DEFAULT_DATE,
catchup=True,
)
TaskInstanceHistory.record_ti(ti, session=session)
session.flush()
tih = (
session.query(TaskInstanceHistory)
.filter_by(
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
map_index=ti.map_index,
try_number=ti.try_number,
)
.one()
)
fth = FileTaskHandler("")
rendered_ti = fth._render_filename(ti, ti.try_number, session=session)
rendered_tih = fth._render_filename(tih, ti.try_number, session=session)
expected = f"{ti.id}.log"
assert rendered_ti == expected
assert rendered_tih == expected


class TestLogUrl:
def test_log_retrieval_valid(self, create_task_instance):